diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..fdd2777c7
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,7 @@
+# .git-blame-ignore-revs
+# Re-formatted entire code base with black
+7ebf0753485c931db4135953dcd0864b4d089ed5
+
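+# Note: git does not read this file automatically. Point blame at it once per
+# clone (requires git >= 2.23):
+#   git config blame.ignoreRevsFile .git-blame-ignore-revs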
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index a9c6d249a..6c50ace97 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -3,6 +3,8 @@
### Checklist
- [ ] Wrote a description of my changes above
+- [ ] Formatted my code with [`black`](https://black.readthedocs.io/en/stable/index.html)
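+  (e.g. run `pip install black` and then `black .` from the repository root)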
- [ ] Added a bullet point for my changes to the top of the `CHANGELOG.md` file
- [ ] Added or modified unit tests to reflect my changes
- [ ] Manually tested with a notebook
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 000000000..b04fb15cb
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,12 @@
+name: Lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
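+      # Note: with no `options` input, psf/black@stable defaults to
+      # `black --check --diff .`, so this job fails on any unformatted file.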
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d51ace3e8..86b755c27 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@
### Updates
* Officially drop support for Python 3.6. Sparkmagic will not guarantee Python 3.6 compatibility moving forward.
+* Re-format all code with [`black`](https://black.readthedocs.io/en/stable/index.html) and validate via CI
+
### Bug Fixes
diff --git a/README.md b/README.md
index 6e01367f6..f059e26da 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
-[![Build Status](https://travis-ci.org/jupyter-incubator/sparkmagic.svg?branch=master)](https://travis-ci.org/jupyter-incubator/sparkmagic) [![Join the chat at https://gitter.im/sparkmagic/Lobby](https://badges.gitter.im/sparkmagic/Lobby.svg)](https://gitter.im/sparkmagic/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
+[![Build Status](https://travis-ci.org/jupyter-incubator/sparkmagic.svg?branch=master)](https://travis-ci.org/jupyter-incubator/sparkmagic) [![Join the chat at https://gitter.im/sparkmagic/Lobby](https://badges.gitter.im/sparkmagic/Lobby.svg)](https://gitter.im/sparkmagic/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
# sparkmagic
diff --git a/autovizwidget/autovizwidget/plotlygraphs/bargraph.py b/autovizwidget/autovizwidget/plotlygraphs/bargraph.py
index 61e671beb..4de275e4e 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/bargraph.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/bargraph.py
@@ -10,4 +10,3 @@ class BarGraph(GraphBase):
def _get_data(self, df, encoding):
x_values, y_values = GraphBase._get_x_y_values(df, encoding)
return [Bar(x=x_values, y=y_values)]
-
diff --git a/autovizwidget/autovizwidget/plotlygraphs/datagraph.py b/autovizwidget/autovizwidget/plotlygraphs/datagraph.py
index 19021a23d..358413d07 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/datagraph.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/datagraph.py
@@ -8,6 +8,7 @@
class DataGraph(object):
"""This does not use the table version of plotly because it freezes up the browser for >60 rows. Instead, we use
pandas df HTML representation."""
+
def __init__(self, display=None):
if display is None:
self.display = IpythonDisplay()
@@ -21,7 +22,8 @@ def render(self, df, encoding, output):
show_dimensions = pd.get_option("display.show_dimensions")
# This will hide the index column for pandas df.
- self.display.html("""
+ self.display.html(
+ """
-""")
- self.display.html(df.to_html(max_rows=max_rows, max_cols=max_cols,
- show_dimensions=show_dimensions, notebook=True, classes="hideme"))
+"""
+ )
+ self.display.html(
+ df.to_html(
+ max_rows=max_rows,
+ max_cols=max_cols,
+ show_dimensions=show_dimensions,
+ notebook=True,
+ classes="hideme",
+ )
+ )
@staticmethod
def display_logarithmic_x_axis():
diff --git a/autovizwidget/autovizwidget/plotlygraphs/graphbase.py b/autovizwidget/autovizwidget/plotlygraphs/graphbase.py
index 9b46f906c..7e09fe8b7 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/graphbase.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/graphbase.py
@@ -3,6 +3,7 @@
from plotly.graph_objs import Figure, Layout
from plotly.offline import iplot
+
try:
from pandas.core.base import DataError
except:
@@ -29,16 +30,20 @@ def render(self, df, encoding, output):
type_x_axis = self._get_type_axis(encoding.logarithmic_x_axis)
type_y_axis = self._get_type_axis(encoding.logarithmic_y_axis)
- layout = Layout(xaxis=dict(type=type_x_axis, rangemode="tozero", title=encoding.x),
- yaxis=dict(type=type_y_axis, rangemode="tozero", title=encoding.y))
+ layout = Layout(
+ xaxis=dict(type=type_x_axis, rangemode="tozero", title=encoding.x),
+ yaxis=dict(type=type_y_axis, rangemode="tozero", title=encoding.y),
+ )
with output:
try:
fig = Figure(data=data, layout=layout)
iplot(fig, show_link=False)
except TypeError:
- print("\n\n\nPlease select another set of X and Y axis, because the type of the current axis do\n"
- "not support aggregation over it.")
+ print(
+ "\n\n\nPlease select another set of X and Y axis, because the type of the current axis do\n"
+ "not support aggregation over it."
+ )
@staticmethod
def display_x():
@@ -68,10 +73,9 @@ def _get_data(self, df, encoding):
@staticmethod
def _get_x_y_values(df, encoding):
try:
- x_values, y_values = GraphBase._get_x_y_values_aggregated(df,
- encoding.x,
- encoding.y,
- encoding.y_aggregation)
+ x_values, y_values = GraphBase._get_x_y_values_aggregated(
+ df, encoding.x, encoding.y, encoding.y_aggregation
+ )
except ValueError:
x_values = GraphBase._get_x_values(df, encoding)
y_values = GraphBase._get_y_values(df, encoding)
@@ -99,8 +103,11 @@ def _get_x_y_values_aggregated(df, x_column, y_column, y_aggregation):
try:
df_grouped = df.groupby(x_column)
except TypeError:
- raise InvalidEncodingError("Cannot group by X column '{}' because of its type: '{}'."
- .format(df[x_column].dtype))
+ raise InvalidEncodingError(
+ "Cannot group by X column '{}' because of its type: '{}'.".format(
+ x_column, df[x_column].dtype
+ )
+ )
else:
try:
if y_aggregation == Encoding.y_agg_avg:
@@ -114,20 +121,30 @@ def _get_x_y_values_aggregated(df, x_column, y_column, y_aggregation):
elif y_aggregation == Encoding.y_agg_count:
df_transformed = df_grouped.count()
else:
- raise ValueError("Y aggregation '{}' not supported.".format(y_aggregation))
+ raise ValueError(
+ "Y aggregation '{}' not supported.".format(y_aggregation)
+ )
except (DataError, ValueError) as err:
- raise InvalidEncodingError("Cannot aggregate column '{}' with aggregation function '{}' because:\n\t'{}'."
- .format(y_column, y_aggregation, err))
+ raise InvalidEncodingError(
+ "Cannot aggregate column '{}' with aggregation function '{}' because:\n\t'{}'.".format(
+ y_column, y_aggregation, err
+ )
+ )
except TypeError:
- raise InvalidEncodingError("Cannot aggregate column '{}' with aggregation function '{}' because the type\n"
- "cannot be aggregated over."
- .format(y_column, y_aggregation))
+ raise InvalidEncodingError(
+ "Cannot aggregate column '{}' with aggregation function '{}' because the type\n"
+ "cannot be aggregated over.".format(y_column, y_aggregation)
+ )
else:
df_transformed = df_transformed.reset_index()
if y_column not in df_transformed.columns:
- raise InvalidEncodingError("Y column '{}' is not valid with aggregation function '{}'. Please select "
- "a different\naggregation function.".format(y_column, y_aggregation))
+ raise InvalidEncodingError(
+ "Y column '{}' is not valid with aggregation function '{}'. Please select "
+ "a different\naggregation function.".format(
+ y_column, y_aggregation
+ )
+ )
x_values = df_transformed[x_column].tolist()
y_values = df_transformed[y_column].tolist()
diff --git a/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py b/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py
index 82dbf8868..9d6ff1755 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py
@@ -14,7 +14,6 @@
class GraphRenderer(object):
-
@staticmethod
def render(df, encoding, output):
with output:
diff --git a/autovizwidget/autovizwidget/plotlygraphs/linegraph.py b/autovizwidget/autovizwidget/plotlygraphs/linegraph.py
index 35356bd8d..e3612f7b2 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/linegraph.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/linegraph.py
@@ -7,7 +7,6 @@
class LineGraph(GraphBase):
-
def _get_data(self, df, encoding):
x_values, y_values = GraphBase._get_x_y_values(df, encoding)
return [Scatter(x=x_values, y=y_values)]
diff --git a/autovizwidget/autovizwidget/plotlygraphs/piegraph.py b/autovizwidget/autovizwidget/plotlygraphs/piegraph.py
index 10b1a3fb5..d383c924f 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/piegraph.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/piegraph.py
@@ -3,11 +3,12 @@
from plotly.graph_objs import Pie, Figure
from plotly.offline import iplot
+
try:
from pandas.core.base import DataError
except:
from pandas.core.groupby import DataError
-
+
import autovizwidget.utils.configuration as conf
from .graphbase import GraphBase
@@ -24,13 +25,19 @@ def render(df, encoding, output):
values, labels = PieGraph._get_x_values_labels(df, encoding)
except TypeError:
with output:
- print("\n\n\nCannot group by X selection because of its type: '{}'. Please select another column."
- .format(df[encoding.x].dtype))
+ print(
+ "\n\n\nCannot group by X selection because of its type: '{}'. Please select another column.".format(
+ df[encoding.x].dtype
+ )
+ )
return
except (ValueError, DataError):
with output:
- print("\n\n\nCannot group by X selection. Please select another column."
- .format(df[encoding.x].dtype))
+ print(
+ "\n\n\nCannot group by X selection. Please select another column."
+ )
if df.size == 0:
print("\n\n\nCannot display a pie graph for an empty data set.")
return
@@ -42,9 +49,12 @@ def render(df, encoding, output):
# 500 rows take ~15 s.
# 100 rows is almost automatic.
if len(values) > max_slices_pie_graph:
- print("There's {} values in your pie graph, which would render the graph unresponsive.\n"
- "Please select another X with at most {} possible values."
- .format(len(values), max_slices_pie_graph))
+ print(
+ "There's {} values in your pie graph, which would render the graph unresponsive.\n"
+ "Please select another X with at most {} possible values.".format(
+ len(values), max_slices_pie_graph
+ )
+ )
else:
data = [Pie(values=values, labels=labels)]
diff --git a/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py b/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py
index 92c88f175..ebf938886 100644
--- a/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py
+++ b/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py
@@ -4,7 +4,6 @@
class ScatterGraph(GraphBase):
-
def _get_data(self, df, encoding):
x_values, y_values = GraphBase._get_x_y_values(df, encoding)
- return [Scatter(x=x_values, y=y_values, mode='markers')]
+ return [Scatter(x=x_values, y=y_values, mode="markers")]
diff --git a/autovizwidget/autovizwidget/tests/test_autovizwidget.py b/autovizwidget/autovizwidget/tests/test_autovizwidget.py
index 1f7b7a555..ee648b970 100644
--- a/autovizwidget/autovizwidget/tests/test_autovizwidget.py
+++ b/autovizwidget/autovizwidget/tests/test_autovizwidget.py
@@ -27,12 +27,14 @@ def _setup():
renderer.display_logarithmic_x_axis.return_value = True
renderer.display_logarithmic_y_axis.return_value = True
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32},
+ ]
df = pd.DataFrame(records)
encoding = Encoding(chart_type="table", x="date", y="temp_diff")
@@ -53,8 +55,16 @@ def _teardown():
@with_setup(_setup, _teardown)
def test_on_render_viz():
- widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory,
- encoding_widget, ipython_display, spark_events=spark_events, testing=True)
+ widget = AutoVizWidget(
+ df,
+ encoding,
+ renderer,
+ ipywidget_factory,
+ encoding_widget,
+ ipython_display,
+ spark_events=spark_events,
+ testing=True,
+ )
# on_render_viz is called in the constructor, so no need to call it here.
output.clear_output.assert_called_once_with()
@@ -76,59 +86,129 @@ def test_on_render_viz():
encoding._chart_type = Encoding.chart_type_scatter
widget.on_render_viz()
assert_equals(len(spark_events.emit_graph_render_event.mock_calls), 2)
- assert_equals(spark_events.emit_graph_render_event.call_args, call(Encoding.chart_type_scatter))
+ assert_equals(
+ spark_events.emit_graph_render_event.call_args,
+ call(Encoding.chart_type_scatter),
+ )
@with_setup(_setup, _teardown)
def test_create_viz_types_buttons():
- df_single_column = pd.DataFrame([{u'buildingID': 0}])
- widget = AutoVizWidget(df_single_column, encoding, renderer, ipywidget_factory,
- encoding_widget, ipython_display, spark_events=spark_events, testing=True)
+ df_single_column = pd.DataFrame([{"buildingID": 0}])
+ widget = AutoVizWidget(
+ df_single_column,
+ encoding,
+ renderer,
+ ipywidget_factory,
+ encoding_widget,
+ ipython_display,
+ spark_events=spark_events,
+ testing=True,
+ )
# create_viz_types_buttons is called in the constructor, so no need to call it here.
- assert call(description=Encoding.chart_type_table) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_pie) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_line) not in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_area) not in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_bar) not in ipywidget_factory.get_button.mock_calls
+ assert (
+ call(description=Encoding.chart_type_table)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_pie)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_line)
+ not in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_area)
+ not in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_bar)
+ not in ipywidget_factory.get_button.mock_calls
+ )
spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type)
- widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory,
- encoding_widget, ipython_display, spark_events=spark_events, testing=True)
+ widget = AutoVizWidget(
+ df,
+ encoding,
+ renderer,
+ ipywidget_factory,
+ encoding_widget,
+ ipython_display,
+ spark_events=spark_events,
+ testing=True,
+ )
# create_viz_types_buttons is called in the constructor, so no need to call it here.
- assert call(description=Encoding.chart_type_table) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_pie) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_line) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_area) in ipywidget_factory.get_button.mock_calls
- assert call(description=Encoding.chart_type_bar) in ipywidget_factory.get_button.mock_calls
+ assert (
+ call(description=Encoding.chart_type_table)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_pie)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_line)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_area)
+ in ipywidget_factory.get_button.mock_calls
+ )
+ assert (
+ call(description=Encoding.chart_type_bar)
+ in ipywidget_factory.get_button.mock_calls
+ )
@with_setup(_setup, _teardown)
def test_create_viz_empty_df():
df = pd.DataFrame([])
- widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory,
- encoding_widget, ipython_display, spark_events=spark_events, testing=True)
+ widget = AutoVizWidget(
+ df,
+ encoding,
+ renderer,
+ ipywidget_factory,
+ encoding_widget,
+ ipython_display,
+ spark_events=spark_events,
+ testing=True,
+ )
ipywidget_factory.get_button.assert_not_called()
ipywidget_factory.get_html.assert_called_once_with("No results.")
ipython_display.display.assert_called_with(ipywidget_factory.get_html.return_value)
spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type)
+
@with_setup(_setup, _teardown)
def test_convert_to_displayable_dataframe():
- bool_df = pd.DataFrame([{u'bool_col': True, u'int_col': 0, u'float_col': 3.0},
- {u'bool_col': False, u'int_col': 100, u'float_col': 0.7}])
+ bool_df = pd.DataFrame(
+ [
+ {"bool_col": True, "int_col": 0, "float_col": 3.0},
+ {"bool_col": False, "int_col": 100, "float_col": 0.7},
+ ]
+ )
copy_of_df = bool_df.copy()
- widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory,
- encoding_widget, ipython_display, spark_events=spark_events, testing=True)
+ widget = AutoVizWidget(
+ df,
+ encoding,
+ renderer,
+ ipywidget_factory,
+ encoding_widget,
+ ipython_display,
+ spark_events=spark_events,
+ testing=True,
+ )
result = AutoVizWidget._convert_to_displayable_dataframe(bool_df)
# Ensure original DF not changed
assert_frame_equal(bool_df, copy_of_df)
- assert_series_equal(bool_df[u'int_col'], result[u'int_col'])
- assert_series_equal(bool_df[u'float_col'], result[u'float_col'])
- assert_equals(result.dtypes[u'bool_col'], object)
- assert_equals(len(result[u'bool_col']), 2)
- assert_equals(result[u'bool_col'][0], 'True')
- assert_equals(result[u'bool_col'][1], 'False')
+ assert_series_equal(bool_df["int_col"], result["int_col"])
+ assert_series_equal(bool_df["float_col"], result["float_col"])
+ assert_equals(result.dtypes["bool_col"], object)
+ assert_equals(len(result["bool_col"]), 2)
+ assert_equals(result["bool_col"][0], "True")
+ assert_equals(result["bool_col"][1], "False")
spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type)
diff --git a/autovizwidget/autovizwidget/tests/test_encodingwidget.py b/autovizwidget/autovizwidget/tests/test_encodingwidget.py
index 622a5add7..3153375ec 100644
--- a/autovizwidget/autovizwidget/tests/test_encodingwidget.py
+++ b/autovizwidget/autovizwidget/tests/test_encodingwidget.py
@@ -16,12 +16,14 @@
def _setup():
global df, encoding, ipywidget_factory, change_hook
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, u'\u263A': True},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, u'\u263A': True},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, u'\u263A': True},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, u'\u263A': True},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, u'\u263A': True},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, u'\u263A': True}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12, "\u263A": True},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0, "\u263A": True},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11, "\u263A": True},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5, "\u263A": True},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19, "\u263A": True},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32, "\u263A": True},
+ ]
df = pd.DataFrame(records)
encoding = Encoding(chart_type="table", x="date", y="temp_diff")
@@ -38,12 +40,14 @@ def _teardown():
@with_setup(_setup, _teardown)
def test_encoding_with_all_none_doesnt_throw():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32},
+ ]
df = pd.DataFrame(records)
encoding = Encoding()
@@ -53,16 +57,47 @@ def test_encoding_with_all_none_doesnt_throw():
EncodingWidget(df, encoding, change_hook, ipywidget_factory, testing=True)
- assert call(description='X', value=None, options={'date': 'date', 'temp_diff': 'temp_diff', '-': None,
- 'buildingID': 'buildingID'}) \
+ assert (
+ call(
+ description="X",
+ value=None,
+ options={
+ "date": "date",
+ "temp_diff": "temp_diff",
+ "-": None,
+ "buildingID": "buildingID",
+ },
+ )
in ipywidget_factory.get_dropdown.mock_calls
- assert call(description='Y', value=None, options={'date': 'date', 'temp_diff': 'temp_diff', '-': None,
- 'buildingID': 'buildingID'}) \
+ )
+ assert (
+ call(
+ description="Y",
+ value=None,
+ options={
+ "date": "date",
+ "temp_diff": "temp_diff",
+ "-": None,
+ "buildingID": "buildingID",
+ },
+ )
in ipywidget_factory.get_dropdown.mock_calls
- assert call(description='Func.', value='none', options={'Max': 'Max', 'Sum': 'Sum', 'Avg': 'Avg',
- '-': 'None', 'Min': 'Min', 'Count': 'Count'}) \
+ )
+ assert (
+ call(
+ description="Func.",
+ value="none",
+ options={
+ "Max": "Max",
+ "Sum": "Sum",
+ "Avg": "Avg",
+ "-": "None",
+ "Min": "Min",
+ "Count": "Count",
+ },
+ )
in ipywidget_factory.get_dropdown.mock_calls
-
+ )
@with_setup(_setup, _teardown)
diff --git a/autovizwidget/autovizwidget/tests/test_plotlygraphs.py b/autovizwidget/autovizwidget/tests/test_plotlygraphs.py
index 427ed1b45..5b4672660 100644
--- a/autovizwidget/autovizwidget/tests/test_plotlygraphs.py
+++ b/autovizwidget/autovizwidget/tests/test_plotlygraphs.py
@@ -16,51 +16,86 @@ def test_graph_base_display_methods():
def test_graphbase_get_x_y_values():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, u"str": "str"},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, u"str": "str"},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, u"str": "str"},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, u"str": "str"},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, u"str": "str"},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, u"str": "str"}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12, "str": "str"},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0, "str": "str"},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11, "str": "str"},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5, "str": "str"},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19, "str": "str"},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32, "str": "str"},
+ ]
df = pd.DataFrame(records)
- expected_xs = [u'6/1/13', u'6/1/14', u'6/1/15', u'6/1/16', u'6/1/17']
-
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum)
+ expected_xs = ["6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"]
+
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_sum,
+ )
xs, yx = GraphBase._get_x_y_values(df, encoding)
assert xs == expected_xs
assert yx == [12, 11, 5, 19, 32]
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_avg)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_avg,
+ )
xs, yx = GraphBase._get_x_y_values(df, encoding)
assert xs == expected_xs
assert yx == [6, 11, 5, 19, 32]
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_max)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_max,
+ )
xs, yx = GraphBase._get_x_y_values(df, encoding)
assert xs == expected_xs
assert yx == [12, 11, 5, 19, 32]
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_min)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_min,
+ )
xs, yx = GraphBase._get_x_y_values(df, encoding)
assert xs == expected_xs
assert yx == [0, 11, 5, 19, 32]
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_none)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_none,
+ )
xs, yx = GraphBase._get_x_y_values(df, encoding)
- assert xs == [u'6/1/13', u'6/1/13', u'6/1/14', u'6/1/15', u'6/1/16', u'6/1/17']
+ assert xs == ["6/1/13", "6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"]
assert yx == [12, 0, 11, 5, 19, 32]
try:
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="buildingID", y="date",
- y_aggregation=Encoding.y_agg_avg)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="buildingID",
+ y="date",
+ y_aggregation=Encoding.y_agg_avg,
+ )
GraphBase._get_x_y_values(df, encoding)
assert False
except InvalidEncodingError:
pass
try:
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="str",
- y_aggregation=Encoding.y_agg_avg)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="str",
+ y_aggregation=Encoding.y_agg_avg,
+ )
GraphBase._get_x_y_values(df, encoding)
assert False
except InvalidEncodingError:
@@ -75,21 +110,33 @@ def test_pie_graph_display_methods():
def test_pie_graph_get_values_labels():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32},
+ ]
df = pd.DataFrame(records)
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y=None, y_aggregation=Encoding.y_agg_sum)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y=None,
+ y_aggregation=Encoding.y_agg_sum,
+ )
values, labels = PieGraph._get_x_values_labels(df, encoding)
assert values == [2, 1, 1, 1, 1]
assert labels == ["6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"]
-
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum)
+
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_sum,
+ )
values, labels = PieGraph._get_x_values_labels(df, encoding)
@@ -98,14 +145,21 @@ def test_pie_graph_get_values_labels():
def test_data_graph_render():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": 12},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": 0},
+ {"buildingID": 2, "date": "6/1/14", "temp_diff": 11},
+ {"buildingID": 0, "date": "6/1/15", "temp_diff": 5},
+ {"buildingID": 1, "date": "6/1/16", "temp_diff": 19},
+ {"buildingID": 2, "date": "6/1/17", "temp_diff": 32},
+ ]
df = pd.DataFrame(records)
- encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_line,
+ x="date",
+ y="temp_diff",
+ y_aggregation=Encoding.y_agg_sum,
+ )
display = MagicMock()
data = DataGraph(display)
diff --git a/autovizwidget/autovizwidget/tests/test_sparkevents.py b/autovizwidget/autovizwidget/tests/test_sparkevents.py
index 33493901c..73de4bb01 100644
--- a/autovizwidget/autovizwidget/tests/test_sparkevents.py
+++ b/autovizwidget/autovizwidget/tests/test_sparkevents.py
@@ -24,29 +24,33 @@ def _teardown():
@with_setup(_setup, _teardown)
def test_not_emit_graph_render_event_when_not_registered():
event_name = GRAPH_RENDER_EVENT
- graph_type = 'Bar'
+ graph_type = "Bar"
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (GRAPH_TYPE, graph_type)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (GRAPH_TYPE, graph_type),
+ ]
events.emit_graph_render_event(graph_type)
events.get_utc_date_time.assert_called_with()
assert not events.handler.handle_event.called
-
-
+
+
@with_setup(_setup, _teardown)
def test_emit_graph_render_event_when_registered():
conf.override(conf.events_handler.__name__, events.handler)
event_name = GRAPH_RENDER_EVENT
- graph_type = 'Bar'
-
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (GRAPH_TYPE, graph_type)]
+ graph_type = "Bar"
+
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (GRAPH_TYPE, graph_type),
+ ]
events.emit_graph_render_event(graph_type)
diff --git a/autovizwidget/autovizwidget/tests/test_utils.py b/autovizwidget/autovizwidget/tests/test_utils.py
index a26f98e1f..9abb29ad1 100644
--- a/autovizwidget/autovizwidget/tests/test_utils.py
+++ b/autovizwidget/autovizwidget/tests/test_utils.py
@@ -12,12 +12,50 @@
def _setup():
global df, encoding
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, "mystr": "alejandro", "mystr2": "1"},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, "mystr": "alejandro", "mystr2": "1"},
- {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, "mystr": "alejandro", "mystr2": "1"},
- {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, "mystr": "alejandro", "mystr2": "1.0"},
- {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, "mystr": "alejandro", "mystr2": "1"},
- {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, "mystr": "alejandro", "mystr2": "1"}]
+ records = [
+ {
+ "buildingID": 0,
+ "date": "6/1/13",
+ "temp_diff": 12,
+ "mystr": "alejandro",
+ "mystr2": "1",
+ },
+ {
+ "buildingID": 1,
+ "date": "6/1/13",
+ "temp_diff": 0,
+ "mystr": "alejandro",
+ "mystr2": "1",
+ },
+ {
+ "buildingID": 2,
+ "date": "6/1/14",
+ "temp_diff": 11,
+ "mystr": "alejandro",
+ "mystr2": "1",
+ },
+ {
+ "buildingID": 0,
+ "date": "6/1/15",
+ "temp_diff": 5,
+ "mystr": "alejandro",
+ "mystr2": "1.0",
+ },
+ {
+ "buildingID": 1,
+ "date": "6/1/16",
+ "temp_diff": 19,
+ "mystr": "alejandro",
+ "mystr2": "1",
+ },
+ {
+ "buildingID": 2,
+ "date": "6/1/17",
+ "temp_diff": 32,
+ "mystr": "alejandro",
+ "mystr2": "1",
+ },
+ ]
df = pd.DataFrame(records)
encoding = Encoding(chart_type="table", x="date", y="temp_diff")
@@ -46,24 +84,27 @@ def _check(d, expected):
x = utils.select_x(d)
assert x == expected
- data = dict(col1=[1.0, 2.0, 3.0], # Q
- col2=['A', 'B', 'C'], # N
- col3=pd.date_range('2012', periods=3, freq='A')) # T
- _check(data, 'col3')
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Q
+ col2=["A", "B", "C"], # N
+ col3=pd.date_range("2012", periods=3, freq="A"),
+ ) # T
+ _check(data, "col3")
- data = dict(col1=[1.0, 2.0, 3.0], # Q
- col2=['A', 'B', 'C']) # N
- _check(data, 'col2')
+ data = dict(col1=[1.0, 2.0, 3.0], col2=["A", "B", "C"]) # Q # N
+ _check(data, "col2")
data = dict(col1=[1.0, 2.0, 3.0]) # Q
- _check(data, 'col1')
+ _check(data, "col1")
# Custom order
- data = dict(col1=[1.0, 2.0, 3.0], # Q
- col2=['A', 'B', 'C'], # N
- col3=pd.date_range('2012', periods=3, freq='A'), # T
- col4=pd.date_range('2012', periods=3, freq='A')) # T
- selected_x = utils.select_x(data, ['N', 'T', 'Q', 'O'])
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Q
+ col2=["A", "B", "C"], # N
+ col3=pd.date_range("2012", periods=3, freq="A"), # T
+ col4=pd.date_range("2012", periods=3, freq="A"),
+ ) # T
+ selected_x = utils.select_x(data, ["N", "T", "Q", "O"])
assert selected_x == "col2"
# Len < 1
@@ -72,25 +113,31 @@ def _check(d, expected):
def test_select_y():
def _check(d, expected):
- x = 'col1'
+ x = "col1"
y = utils.select_y(d, x)
assert y == expected
- data = dict(col1=[1.0, 2.0, 3.0], # Chosen X
- col2=['A', 'B', 'C'], # N
- col3=pd.date_range('2012', periods=3, freq='A'), # T
- col4=pd.date_range('2012', periods=3, freq='A'), # T
- col5=[1.0, 2.0, 3.0]) # Q
- _check(data, 'col5')
-
- data = dict(col1=[1.0, 2.0, 3.0], # Chosen X
- col2=['A', 'B', 'C'], # N
- col3=pd.date_range('2012', periods=3, freq='A')) # T
- _check(data, 'col2')
-
- data = dict(col1=[1.0, 2.0, 3.0], # Chosen X
- col2=pd.date_range('2012', periods=3, freq='A')) # T
- _check(data, 'col2')
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Chosen X
+ col2=["A", "B", "C"], # N
+ col3=pd.date_range("2012", periods=3, freq="A"), # T
+ col4=pd.date_range("2012", periods=3, freq="A"), # T
+ col5=[1.0, 2.0, 3.0],
+ ) # Q
+ _check(data, "col5")
+
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Chosen X
+ col2=["A", "B", "C"], # N
+ col3=pd.date_range("2012", periods=3, freq="A"),
+ ) # T
+ _check(data, "col2")
+
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Chosen X
+ col2=pd.date_range("2012", periods=3, freq="A"),
+ ) # T
+ _check(data, "col2")
# No data
assert utils.select_y(None, "something") is None
@@ -102,12 +149,14 @@ def _check(d, expected):
assert utils.select_y(df, None) is None
# Custom order
- data = dict(col1=[1.0, 2.0, 3.0], # Chosen X
- col2=['A', 'B', 'C'], # N
- col3=pd.date_range('2012', periods=3, freq='A'), # T
- col4=pd.date_range('2012', periods=3, freq='A'), # T
- col5=[1.0, 2.0, 3.0], # Q
- col6=[1.0, 2.0, 3.0]) # Q
- selected_x = 'col1'
- selected_y = utils.select_y(data, selected_x, ['N', 'T', 'Q', 'O'])
- assert selected_y == 'col2'
+ data = dict(
+ col1=[1.0, 2.0, 3.0], # Chosen X
+ col2=["A", "B", "C"], # N
+ col3=pd.date_range("2012", periods=3, freq="A"), # T
+ col4=pd.date_range("2012", periods=3, freq="A"), # T
+ col5=[1.0, 2.0, 3.0], # Q
+ col6=[1.0, 2.0, 3.0],
+ ) # Q
+ selected_x = "col1"
+ selected_y = utils.select_y(data, selected_x, ["N", "T", "Q", "O"])
+ assert selected_y == "col2"
diff --git a/autovizwidget/autovizwidget/utils/configuration.py b/autovizwidget/autovizwidget/utils/configuration.py
index 1cd048525..7e58f28f5 100644
--- a/autovizwidget/autovizwidget/utils/configuration.py
+++ b/autovizwidget/autovizwidget/utils/configuration.py
@@ -1,12 +1,15 @@
# Distributed under the terms of the Modified BSD License.
-from hdijupyterutils.constants import EVENTS_HANDLER_CLASS_NAME, LOGGING_CONFIG_CLASS_NAME
+from hdijupyterutils.constants import (
+ EVENTS_HANDLER_CLASS_NAME,
+ LOGGING_CONFIG_CLASS_NAME,
+)
from hdijupyterutils.utils import join_paths
from hdijupyterutils.configuration import override as _override
from hdijupyterutils.configuration import override_all as _override_all
from hdijupyterutils.configuration import with_override
from .constants import HOME_PATH, CONFIG_FILE
-
+
d = {}
path = join_paths(HOME_PATH, CONFIG_FILE)
@@ -18,15 +21,18 @@ def override(config, value):
def override_all(obj):
_override_all(d, obj)
-
+
+
_with_override = with_override(d, path)
-
+
# Configs
+
@_with_override
def events_handler():
return None
+
@_with_override
def max_slices_pie_graph():
return 100
diff --git a/autovizwidget/autovizwidget/utils/events.py b/autovizwidget/autovizwidget/utils/events.py
index f3415ab90..87ee3a79a 100644
--- a/autovizwidget/autovizwidget/utils/events.py
+++ b/autovizwidget/autovizwidget/utils/events.py
@@ -17,9 +17,11 @@ def emit_graph_render_event(self, graph_type):
event_name = GRAPH_RENDER_EVENT
time_stamp = self.get_utc_date_time()
- kwargs_list = [(EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (GRAPH_TYPE, graph_type)]
-
- if self.emit:
+ kwargs_list = [
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (GRAPH_TYPE, graph_type),
+ ]
+
+ if self.emit:
self.send_to_handler(kwargs_list)
diff --git a/autovizwidget/autovizwidget/widget/autovizwidget.py b/autovizwidget/autovizwidget/widget/autovizwidget.py
index 280785e1e..0c70c3fec 100644
--- a/autovizwidget/autovizwidget/widget/autovizwidget.py
+++ b/autovizwidget/autovizwidget/widget/autovizwidget.py
@@ -13,13 +13,24 @@
class AutoVizWidget(Box):
- def __init__(self, df, encoding, renderer=None, ipywidget_factory=None, encoding_widget=None, ipython_display=None,
- nested_widget_mode=False, spark_events=None, testing=False, **kwargs):
+ def __init__(
+ self,
+ df,
+ encoding,
+ renderer=None,
+ ipywidget_factory=None,
+ encoding_widget=None,
+ ipython_display=None,
+ nested_widget_mode=False,
+ spark_events=None,
+ testing=False,
+ **kwargs
+ ):
assert encoding is not None
assert df is not None
assert type(df) is pd.DataFrame
- kwargs['orientation'] = 'vertical'
+ kwargs["orientation"] = "vertical"
if not testing:
super(AutoVizWidget, self).__init__((), **kwargs)
@@ -74,14 +85,22 @@ def on_render_viz(self, *args):
self.encoding_widget.show_x(self.renderer.display_x(self.encoding.chart_type))
self.encoding_widget.show_y(self.renderer.display_y(self.encoding.chart_type))
- self.encoding_widget.show_controls(self.renderer.display_controls(self.encoding.chart_type))
- self.encoding_widget.show_logarithmic_x_axis(self.renderer.display_logarithmic_x_axis(self.encoding.chart_type))
- self.encoding_widget.show_logarithmic_y_axis(self.renderer.display_logarithmic_y_axis(self.encoding.chart_type))
+ self.encoding_widget.show_controls(
+ self.renderer.display_controls(self.encoding.chart_type)
+ )
+ self.encoding_widget.show_logarithmic_x_axis(
+ self.renderer.display_logarithmic_x_axis(self.encoding.chart_type)
+ )
+ self.encoding_widget.show_logarithmic_y_axis(
+ self.renderer.display_logarithmic_y_axis(self.encoding.chart_type)
+ )
if len(self.df) > 0:
self.renderer.render(self.df, self.encoding, self.to_display)
else:
with self.to_display:
- self.ipython_display.display(self.ipywidget_factory.get_html('No results.'))
+ self.ipython_display.display(
+ self.ipywidget_factory.get_html("No results.")
+ )
def _create_controls_widget(self):
# Create types of viz hbox
@@ -97,7 +116,9 @@ def _create_viz_types_buttons(self):
children = list()
if len(self.df) > 0:
- self.heading = self.ipywidget_factory.get_html('Type:', width='80px', height='32px')
+ self.heading = self.ipywidget_factory.get_html(
+ "Type:", width="80px", height="32px"
+ )
children.append(self.heading)
self._create_type_button(Encoding.chart_type_table, children)
@@ -130,6 +151,6 @@ def _convert_to_displayable_dataframe(df):
df = df.copy()
# Convert all booleans to string because Plotly doesn't know how to plot booleans,
# but it does know how to plot strings.
- bool_columns = list(df.select_dtypes(include=['bool']).columns)
+ bool_columns = list(df.select_dtypes(include=["bool"]).columns)
df[bool_columns] = df[bool_columns].astype(str)
return df
diff --git a/autovizwidget/autovizwidget/widget/encoding.py b/autovizwidget/autovizwidget/widget/encoding.py
index 9c778aaa5..f2f57114f 100644
--- a/autovizwidget/autovizwidget/widget/encoding.py
+++ b/autovizwidget/autovizwidget/widget/encoding.py
@@ -9,7 +9,13 @@ class Encoding(object):
chart_type_bar = "Bar"
chart_type_pie = "Pie"
chart_type_table = "Table"
- supported_chart_types = [chart_type_line, chart_type_area, chart_type_bar, chart_type_pie, chart_type_table]
+ supported_chart_types = [
+ chart_type_line,
+ chart_type_area,
+ chart_type_bar,
+ chart_type_pie,
+ chart_type_table,
+ ]
y_agg_avg = "Avg"
y_agg_min = "Min"
@@ -17,10 +23,24 @@ class Encoding(object):
y_agg_sum = "Sum"
y_agg_none = "None"
y_agg_count = "Count"
- supported_y_agg = [y_agg_avg, y_agg_min, y_agg_max, y_agg_sum, y_agg_none, y_agg_count]
-
- def __init__(self, chart_type=None, x=None, y=None, y_aggregation=None,
- logarithmic_x_axis=False, logarithmic_y_axis=False):
+ supported_y_agg = [
+ y_agg_avg,
+ y_agg_min,
+ y_agg_max,
+ y_agg_sum,
+ y_agg_none,
+ y_agg_count,
+ ]
+
+ def __init__(
+ self,
+ chart_type=None,
+ x=None,
+ y=None,
+ y_aggregation=None,
+ logarithmic_x_axis=False,
+ logarithmic_y_axis=False,
+ ):
self._chart_type = chart_type
self._x = x
self._y = y
diff --git a/autovizwidget/autovizwidget/widget/encodingwidget.py b/autovizwidget/autovizwidget/widget/encodingwidget.py
index 5690d85ec..38bb0af6d 100644
--- a/autovizwidget/autovizwidget/widget/encodingwidget.py
+++ b/autovizwidget/autovizwidget/widget/encodingwidget.py
@@ -17,12 +17,14 @@
class EncodingWidget(Box):
- def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=False, **kwargs):
+ def __init__(
+ self, df, encoding, change_hook, ipywidget_factory=None, testing=False, **kwargs
+ ):
assert encoding is not None
assert df is not None
assert type(df) is pd.DataFrame
- kwargs['orientation'] = 'vertical'
+ kwargs["orientation"] = "vertical"
if not testing:
super(EncodingWidget, self).__init__((), **kwargs)
@@ -36,36 +38,43 @@ def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=Fa
self.widget = self.ipywidget_factory.get_vbox()
- self.title = self.ipywidget_factory.get_html('Encoding:', width='148px', height='32px')
+ self.title = self.ipywidget_factory.get_html(
+ "Encoding:", width="148px", height="32px"
+ )
# X view
options_x_view = {text(i): text(i) for i in self.df.columns}
options_x_view["-"] = None
- self.x_view = self.ipywidget_factory.get_dropdown(options=options_x_view,
- description="X", value=self.encoding.x)
- self.x_view.on_trait_change(self._x_changed_callback, 'value')
+ self.x_view = self.ipywidget_factory.get_dropdown(
+ options=options_x_view, description="X", value=self.encoding.x
+ )
+ self.x_view.on_trait_change(self._x_changed_callback, "value")
self.x_view.layout.width = "200px"
# Y
options_y_view = {text(i): text(i) for i in self.df.columns}
options_y_view["-"] = None
- y_column_view = self.ipywidget_factory.get_dropdown(options=options_y_view,
- description="Y", value=self.encoding.y)
- y_column_view.on_trait_change(self._y_changed_callback, 'value')
+ y_column_view = self.ipywidget_factory.get_dropdown(
+ options=options_y_view, description="Y", value=self.encoding.y
+ )
+ y_column_view.on_trait_change(self._y_changed_callback, "value")
y_column_view.layout.width = "200px"
# Y aggregator
value_for_view = self._get_value_for_aggregation(self.encoding.y_aggregation)
self.y_agg_view = self.ipywidget_factory.get_dropdown(
- options={"-": Encoding.y_agg_none,
- Encoding.y_agg_avg: Encoding.y_agg_avg,
- Encoding.y_agg_min: Encoding.y_agg_min,
- Encoding.y_agg_max: Encoding.y_agg_max,
- Encoding.y_agg_sum: Encoding.y_agg_sum,
- Encoding.y_agg_count: Encoding.y_agg_count},
+ options={
+ "-": Encoding.y_agg_none,
+ Encoding.y_agg_avg: Encoding.y_agg_avg,
+ Encoding.y_agg_min: Encoding.y_agg_min,
+ Encoding.y_agg_max: Encoding.y_agg_max,
+ Encoding.y_agg_sum: Encoding.y_agg_sum,
+ Encoding.y_agg_count: Encoding.y_agg_count,
+ },
description="Func.",
- value=value_for_view)
- self.y_agg_view.on_trait_change(self._y_agg_changed_callback, 'value')
+ value=value_for_view,
+ )
+ self.y_agg_view.on_trait_change(self._y_agg_changed_callback, "value")
self.y_agg_view.layout.width = "200px"
# Y view
@@ -74,15 +83,23 @@ def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=Fa
# Logarithmic X axis
self.logarithmic_x_axis = self.ipywidget_factory.get_checkbox(
- description="Log scale X", value=encoding.logarithmic_x_axis)
+ description="Log scale X", value=encoding.logarithmic_x_axis
+ )
self.logarithmic_x_axis.on_trait_change(self._logarithmic_x_callback, "value")
# Logarithmic Y axis
self.logarithmic_y_axis = self.ipywidget_factory.get_checkbox(
- description="Log scale Y", value=encoding.logarithmic_y_axis)
+ description="Log scale Y", value=encoding.logarithmic_y_axis
+ )
self.logarithmic_y_axis.on_trait_change(self._logarithmic_y_callback, "value")
- children = [self.title, self.x_view, self.y_view, self.logarithmic_x_axis, self.logarithmic_y_axis]
+ children = [
+ self.title,
+ self.x_view,
+ self.y_view,
+ self.logarithmic_x_axis,
+ self.logarithmic_y_axis,
+ ]
self.widget.children = children
self.children = [self.widget]
@@ -109,28 +126,28 @@ def _get_value_for_aggregation(self, y_aggregation):
return "none"
def _x_changed_callback(self, name, old_value, new_value):
- self.encoding.x = new_value
- return self.change_hook()
+ self.encoding.x = new_value
+ return self.change_hook()
def _y_changed_callback(self, name, old_value, new_value):
- self.encoding.y = new_value
- return self.change_hook()
+ self.encoding.y = new_value
+ return self.change_hook()
def _y_agg_changed_callback(self, name, old_value, new_value):
- if new_value == "none":
- self.encoding.y_aggregation = None
- else:
- self.encoding.y_aggregation = new_value
- return self.change_hook()
+ if new_value == "none":
+ self.encoding.y_aggregation = None
+ else:
+ self.encoding.y_aggregation = new_value
+ return self.change_hook()
def _logarithmic_x_callback(self, name, old_value, new_value):
- self.encoding.logarithmic_x_axis = new_value
- return self.change_hook()
+ self.encoding.logarithmic_x_axis = new_value
+ return self.change_hook()
def _logarithmic_y_callback(self, name, old_value, new_value):
- self.encoding.logarithmic_y_axis = new_value
- return self.change_hook()
-
+ self.encoding.logarithmic_y_axis = new_value
+ return self.change_hook()
+
def _widget_visible(self, widget, visible):
if visible:
widget.layout.display = "flex"
diff --git a/autovizwidget/autovizwidget/widget/utils.py b/autovizwidget/autovizwidget/widget/utils.py
index b79ebd91b..ec3fa55e3 100644
--- a/autovizwidget/autovizwidget/widget/utils.py
+++ b/autovizwidget/autovizwidget/widget/utils.py
@@ -15,16 +15,28 @@ def infer_vegalite_type(data):
typ = pd.api.types.infer_dtype(data)
- if typ in ['floating', 'mixed-integer-float', 'integer',
- 'mixed-integer', 'complex']:
- typecode = 'Q'
- elif typ in ['string', 'bytes', 'categorical', 'boolean', 'mixed', 'unicode']:
- typecode = 'N'
- elif typ in ['datetime', 'datetime64', 'timedelta',
- 'timedelta64', 'date', 'time', 'period']:
- typecode = 'T'
+ if typ in [
+ "floating",
+ "mixed-integer-float",
+ "integer",
+ "mixed-integer",
+ "complex",
+ ]:
+ typecode = "Q"
+ elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
+ typecode = "N"
+ elif typ in [
+ "datetime",
+ "datetime64",
+ "timedelta",
+ "timedelta64",
+ "date",
+ "time",
+ "period",
+ ]:
+ typecode = "T"
else:
- typecode = 'N'
+ typecode = "N"
return typecode
@@ -33,7 +45,7 @@ def _validate_custom_order(order):
assert len(order) == 4
list_to_check = list(order)
list_to_check.sort()
- assert list_to_check == ['N', 'O', 'Q', 'T']
+ assert list_to_check == ["N", "O", "Q", "T"]
def _classify_data_by_type(data, order, skip=None):
@@ -65,7 +77,7 @@ def select_x(data, order=None):
return None
if order is None:
- order = ['T', 'O', 'N', 'Q']
+ order = ["T", "O", "N", "Q"]
else:
_validate_custom_order(order)
@@ -96,7 +108,7 @@ def select_y(data, x_name, order=None, aggregator=None):
return None
if order is None:
- order = ['Q', 'O', 'N', 'T']
+ order = ["Q", "O", "N", "T"]
else:
_validate_custom_order(order)
@@ -115,6 +127,10 @@ def select_y(data, x_name, order=None, aggregator=None):
def display_dataframe(df):
selected_x = select_x(df)
selected_y = select_y(df, selected_x)
- encoding = Encoding(chart_type=Encoding.chart_type_table, x=selected_x, y=selected_y,
- y_aggregation=Encoding.y_agg_max)
+ encoding = Encoding(
+ chart_type=Encoding.chart_type_table,
+ x=selected_x,
+ y=selected_y,
+ y_aggregation=Encoding.y_agg_max,
+ )
return AutoVizWidget(df, encoding)
diff --git a/hdijupyterutils/hdijupyterutils/configuration.py b/hdijupyterutils/hdijupyterutils/configuration.py
index a18a6a698..0cb47bec6 100644
--- a/hdijupyterutils/hdijupyterutils/configuration.py
+++ b/hdijupyterutils/hdijupyterutils/configuration.py
@@ -11,6 +11,7 @@ def with_override(overrides, path, fsrw_class=None):
"""A decorator which first initializes the overrided configurations,
then checks the global overrided defaults for the given configuration,
calling the function to get the default result otherwise."""
+
def ret(f):
def wrapped_f(*args):
# Can access overrides and path here
@@ -20,22 +21,22 @@ def wrapped_f(*args):
return overrides[name]
else:
return f(*args)
-
+
# Hack! We do this so that we can query the .__name__ of the function
# later to get the name of the configuration dynamically, e.g. for unit tests
wrapped_f.__name__ = f.__name__
return wrapped_f
-
+
return ret
-
+
def override(overrides, path, config, value, fsrw_class=None):
"""Given a string representing a configuration and a value for that configuration,
override the configuration. Initialize the overrided configuration beforehand."""
_initialize(overrides, path, fsrw_class)
overrides[config] = value
-
+
def override_all(overrides, new_overrides):
"""Given a dictionary representing the overrided defaults for this
configuration, initialize the global configuration."""
@@ -50,20 +51,20 @@ def _initialize(overrides, path, fsrw_class):
if not overrides:
new_overrides = _load(path, fsrw_class)
override_all(overrides, new_overrides)
-
-
+
+
def _load(path, fsrw_class=None):
"""Returns a dictionary of configuration by reading from the configuration
file."""
if fsrw_class is None:
fsrw_class = FileSystemReaderWriter
-
+
config_file = fsrw_class(path)
config_file.ensure_file_exists()
config_text = config_file.read_lines()
- line = u"".join(config_text).strip()
-
- if line == u"":
+ line = "".join(config_text).strip()
+
+ if line == "":
overrides = {}
else:
overrides = json.loads(line)
diff --git a/hdijupyterutils/hdijupyterutils/constants.py b/hdijupyterutils/hdijupyterutils/constants.py
index fdd4d8a7a..08a18b9b8 100644
--- a/hdijupyterutils/hdijupyterutils/constants.py
+++ b/hdijupyterutils/hdijupyterutils/constants.py
@@ -1,6 +1,6 @@
-LOGGING_CONFIG_CLASS_NAME = u"hdijupyterutils.filehandler.MagicsFileHandler"
+LOGGING_CONFIG_CLASS_NAME = "hdijupyterutils.filehandler.MagicsFileHandler"
-EVENTS_HANDLER_CLASS_NAME = u"hdijupyterutils.eventshandler.EventsHandler"
+EVENTS_HANDLER_CLASS_NAME = "hdijupyterutils.eventshandler.EventsHandler"
INSTANCE_ID = "InstanceId"
TIMESTAMP = "Timestamp"
EVENT_NAME = "EventName"
diff --git a/hdijupyterutils/hdijupyterutils/filehandler.py b/hdijupyterutils/hdijupyterutils/filehandler.py
index e5a5fc0f9..6f74580b9 100644
--- a/hdijupyterutils/hdijupyterutils/filehandler.py
+++ b/hdijupyterutils/hdijupyterutils/filehandler.py
@@ -6,15 +6,19 @@
class MagicsFileHandler(logging.FileHandler):
"""The default logging handler used by the magics; this behavior can be overridden by modifying the config file"""
+
def __init__(self, **kwargs):
# Simply invokes the behavior of the superclass, but sets the filename keyword argument if it's not already set.
- if 'filename' in kwargs:
+ if "filename" in kwargs:
super(MagicsFileHandler, self).__init__(**kwargs)
else:
- magics_home_path = kwargs.pop(u"home_path")
+ magics_home_path = kwargs.pop("home_path")
logs_folder_name = "logs"
log_file_name = "log_{}.log".format(get_instance_id())
- directory = FileSystemReaderWriter(join_paths(magics_home_path, logs_folder_name))
+ directory = FileSystemReaderWriter(
+ join_paths(magics_home_path, logs_folder_name)
+ )
directory.ensure_path_exists()
- super(MagicsFileHandler, self).__init__(filename=join_paths(directory.path, log_file_name), **kwargs)
-
+ super(MagicsFileHandler, self).__init__(
+ filename=join_paths(directory.path, log_file_name), **kwargs
+ )
diff --git a/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py b/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py
index e948cac2b..5ede09cf6 100644
--- a/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py
+++ b/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py
@@ -4,9 +4,9 @@
class FileSystemReaderWriter(object):
-
def __init__(self, path):
from .utils import expand_path
+
assert path is not None
self.path = expand_path(path)
@@ -16,7 +16,7 @@ def ensure_path_exists(self):
def ensure_file_exists(self):
self._ensure_path_exists(os.path.dirname(self.path))
if not os.path.exists(self.path):
- open(self.path, 'w').close()
+ open(self.path, "w").close()
def read_lines(self):
if os.path.isfile(self.path):
diff --git a/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py b/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py
index 23164872f..249857f1a 100644
--- a/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py
+++ b/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py
@@ -1,7 +1,20 @@
# Copyright (c) 2015 aggftw@gmail.com
# Distributed under the terms of the Modified BSD License.
-from ipywidgets import VBox, Output, Button, HTML, HBox, Dropdown, Checkbox, ToggleButtons, Text, Textarea, Tab, Password
+from ipywidgets import (
+ VBox,
+ Output,
+ Button,
+ HTML,
+ HBox,
+ Dropdown,
+ Checkbox,
+ ToggleButtons,
+ Text,
+ Textarea,
+ Tab,
+ Password,
+)
class IpyWidgetFactory(object):
diff --git a/hdijupyterutils/hdijupyterutils/log.py b/hdijupyterutils/hdijupyterutils/log.py
index 3e9b3dfaf..943403325 100644
--- a/hdijupyterutils/hdijupyterutils/log.py
+++ b/hdijupyterutils/hdijupyterutils/log.py
@@ -9,9 +9,10 @@
class Log(object):
"""Logger for magics. A small wrapper class around the configured logger described in the configuration file"""
+
def __init__(self, logger_name, logging_config, caller_name):
logging.config.dictConfig(logging_config)
-
+
assert caller_name is not None
self._caller_name = caller_name
self.logger_name = logger_name
@@ -30,30 +31,30 @@ def _getLogger(self):
self.logger = logging.getLogger(self.logger_name)
def _transform_log_message(self, message):
- return u'{}\t{}'.format(self._caller_name, message)
+ return "{}\t{}".format(self._caller_name, message)
def logging_config():
return {
- u"version": 1,
- u"formatters": {
- u"magicsFormatter": {
- u"format": u"%(asctime)s\t%(levelname)s\t%(message)s",
- u"datefmt": u""
+ "version": 1,
+ "formatters": {
+ "magicsFormatter": {
+ "format": "%(asctime)s\t%(levelname)s\t%(message)s",
+ "datefmt": "",
}
},
- u"handlers": {
- u"magicsHandler": {
- u"class": LOGGING_CONFIG_CLASS_NAME,
- u"formatter": u"magicsFormatter",
- u"home_path": "~/.hdijupyterutils"
+ "handlers": {
+ "magicsHandler": {
+ "class": LOGGING_CONFIG_CLASS_NAME,
+ "formatter": "magicsFormatter",
+ "home_path": "~/.hdijupyterutils",
}
},
- u"loggers": {
- u"magicsLogger": {
- u"handlers": [u"magicsHandler"],
- u"level": u"DEBUG",
- u"propagate": 0
+ "loggers": {
+ "magicsLogger": {
+ "handlers": ["magicsHandler"],
+ "level": "DEBUG",
+ "propagate": 0,
}
- }
- }
+ },
+ }
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_configuration.py b/hdijupyterutils/hdijupyterutils/tests/test_configuration.py
index fb43024db..d391cdd80 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_configuration.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_configuration.py
@@ -11,7 +11,7 @@
path = "~/.testing/config.json"
original_value = 0
-
+
def module_override(config, value):
global d, path
override(d, path, config, value)
@@ -27,35 +27,35 @@ def module_override_all(obj):
def my_config():
global original_value
return original_value
-
-
+
+
@with_override(d, path)
def my_config_2():
global original_value
- return original_value
-
-
+ return original_value
+
+
# Test helper functions
def _setup():
module_override_all({})
-
-
+
+
def _teardown():
module_override_all({})
-
+
# Unit tests begin
@with_setup(_setup, _teardown)
def test_original_value_without_overrides():
assert_equals(original_value, my_config())
-
+
@with_setup(_setup, _teardown)
def test_original_value_with_overrides():
new_value = 2
module_override(my_config.__name__, new_value)
assert_equals(new_value, my_config())
-
+
@with_setup(_setup, _teardown)
def test_original_values_when_others_override():
@@ -63,15 +63,15 @@ def test_original_values_when_others_override():
module_override(my_config.__name__, new_value)
assert_equals(new_value, my_config())
assert_equals(original_value, my_config_2())
-
-
+
+
@with_setup(_setup, _teardown)
def test_resetting_values_when_others_override():
new_value = 2
module_override(my_config.__name__, new_value)
assert_equals(new_value, my_config())
assert_equals(original_value, my_config_2())
-
+
# Reset
module_override_all({})
assert_equals(original_value, my_config())
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_events.py b/hdijupyterutils/hdijupyterutils/tests/test_events.py
index 614f01d09..1010806d2 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_events.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_events.py
@@ -30,24 +30,26 @@ def test_send_to_handler():
events.send_to_handler(kwargs_list)
events.handler.handle_event.assert_called_once_with(expected_kwargs_list)
-
-
+
+
@with_setup(_setup, _teardown)
@raises(AssertionError)
def test_send_to_handler_asserts_less_than_12():
- kwargs_list = [(TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp),
- (TIMESTAMP, time_stamp)]
+ kwargs_list = [
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ (TIMESTAMP, time_stamp),
+ ]
events.send_to_handler(kwargs_list)
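
The exploded kwargs_list above is black's "magic trailing comma" behavior: once a collection no longer fits on one line, black writes one element per line and appends a trailing comma. A quick sketch using the same API as the earlier aside (format_str only parses the source, so the names need not exist):

import black

# Thirteen tuples push the assignment past 88 columns, so black splits the
# list element-per-line and adds the trailing comma seen in the added lines.
src = "kwargs_list = [" + "(TIMESTAMP, time_stamp), " * 13 + "]\n"
print(black.format_str(src, mode=black.FileMode()))
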
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py b/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
index b5de9dd4c..831221971 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
@@ -3,20 +3,22 @@
from mock import MagicMock
import sys
+
def test_stdout_flush():
ipython_shell = MagicMock()
ipython_display = IpythonDisplay()
ipython_display._ipython_shell = ipython_shell
sys.stdout = MagicMock()
-    ipython_display.write(u'Testing Stdout Flush è')
+    ipython_display.write("Testing Stdout Flush è")
assert sys.stdout.flush.call_count == 1
+
def test_stderr_flush():
ipython_shell = MagicMock()
ipython_display = IpythonDisplay()
ipython_display._ipython_shell = ipython_shell
sys.stderr = MagicMock()
-    ipython_display.send_error(u'Testing Stderr Flush è')
+    ipython_display.send_error("Testing Stderr Flush è")
assert sys.stderr.flush.call_count == 1
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_logger.py b/hdijupyterutils/hdijupyterutils/tests/test_logger.py
index 646eb3ae7..92a1785fd 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_logger.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_logger.py
@@ -11,7 +11,7 @@ def get_logging_config():
def test_log_init():
logging_config = get_logging_config()
- logger = Log('name', logging_config, 'something')
+ logger = Log("name", logging_config, "something")
assert isinstance(logger.logger, logging.Logger)
@@ -22,49 +22,49 @@ def __init__(self):
self.level = self.message = None
def debug(self, message):
- self.level, self.message = 'DEBUG', message
-
+ self.level, self.message = "DEBUG", message
+
def error(self, message):
- self.level, self.message = 'ERROR', message
+ self.level, self.message = "ERROR", message
def info(self, message):
- self.level, self.message = 'INFO', message
+ self.level, self.message = "INFO", message
class MockLog(Log):
def __init__(self, name):
logging_config = get_logging_config()
super(MockLog, self).__init__(name, logging_config, name)
-
+
def _getLogger(self):
- self.logger = MockLogger()
+ self.logger = MockLogger()
def test_log_returnvalue():
- logger = MockLog('test2')
+ logger = MockLog("test2")
assert isinstance(logger.logger, MockLogger)
mock = logger.logger
- logger.debug('word1')
- assert mock.level == 'DEBUG'
- assert_equals(mock.message, 'test2\tword1')
- logger.error('word2')
- assert mock.level == 'ERROR'
- assert mock.message == 'test2\tword2'
- logger.info('word3')
- assert mock.level == 'INFO'
- assert mock.message == 'test2\tword3'
+ logger.debug("word1")
+ assert mock.level == "DEBUG"
+ assert_equals(mock.message, "test2\tword1")
+ logger.error("word2")
+ assert mock.level == "ERROR"
+ assert mock.message == "test2\tword2"
+ logger.info("word3")
+ assert mock.level == "INFO"
+ assert mock.message == "test2\tword3"
def test_log_unicode():
- logger = MockLog('test2')
+ logger = MockLog("test2")
assert isinstance(logger.logger, MockLogger)
mock = logger.logger
-    logger.debug(u'word1è')
-    assert mock.level == 'DEBUG'
-    assert mock.message == u'test2\tword1è'
-    logger.error(u'word2è')
-    assert mock.level == 'ERROR'
-    assert mock.message == u'test2\tword2è'
-    logger.info(u'word3è')
-    assert mock.level == 'INFO'
-    assert mock.message == u'test2\tword3è'
+    logger.debug("word1è")
+    assert mock.level == "DEBUG"
+    assert mock.message == "test2\tword1è"
+    logger.error("word2è")
+    assert mock.level == "ERROR"
+    assert mock.message == "test2\tword2è"
+    logger.info("word3è")
+    assert mock.level == "INFO"
+    assert mock.message == "test2\tword3è"
diff --git a/sparkmagic/sparkmagic/auth/basic.py b/sparkmagic/sparkmagic/auth/basic.py
index edecf67d0..381f1c21e 100644
--- a/sparkmagic/sparkmagic/auth/basic.py
+++ b/sparkmagic/sparkmagic/auth/basic.py
@@ -5,8 +5,10 @@
from requests.auth import HTTPBasicAuth
from .customauth import Authenticator
+
class Basic(HTTPBasicAuth, Authenticator):
"""Basic Access authenticator for SparkMagic"""
+
def __init__(self, parsed_attributes=None):
"""Initializes the Authenticator with the attributes in the attributes
parsed from a %spark magic command if applicable, or with default values
@@ -18,15 +20,17 @@ def __init__(self, parsed_attributes=None):
is created from parsing %spark magic command.
"""
if parsed_attributes is not None:
- if parsed_attributes.user == '' or parsed_attributes.password == '':
- new_exc = BadUserDataException("Need to supply username and password arguments for "\
- "Basic Access Authentication. (e.g. -a username -p password).")
+ if parsed_attributes.user == "" or parsed_attributes.password == "":
+ new_exc = BadUserDataException(
+ "Need to supply username and password arguments for "
+ "Basic Access Authentication. (e.g. -a username -p password)."
+ )
raise new_exc
self.username = parsed_attributes.user
self.password = parsed_attributes.password
else:
- self.username = 'username'
- self.password = 'password'
+ self.username = "username"
+ self.password = "password"
HTTPBasicAuth.__init__(self, self.username, self.password)
Authenticator.__init__(self, parsed_attributes)
@@ -42,15 +46,11 @@ def get_widgets(self, widget_width):
ipywidget_factory = IpyWidgetFactory()
self.user_widget = ipywidget_factory.get_text(
- description='Username:',
- value=self.username,
- width=widget_width
+ description="Username:", value=self.username, width=widget_width
)
self.password_widget = ipywidget_factory.get_password(
- description='Password:',
- value=self.password,
- width=widget_width
+ description="Password:", value=self.password, width=widget_width
)
widgets = [self.user_widget, self.password_widget]
@@ -65,8 +65,11 @@ def update_with_widget_values(self):
def __eq__(self, other):
if not isinstance(other, Basic):
return False
- return self.url == other.url and self.username == other.username and \
- self.password == other.password
+ return (
+ self.url == other.url
+ and self.username == other.username
+ and self.password == other.password
+ )
def __call__(self, request):
return HTTPBasicAuth.__call__(self, request)
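
For context, parsed_attributes is just the Namespace helper exposed by sparkmagic.utils.utils (see the kernelmagics.py hunk below). A hedged usage sketch with placeholder credentials:

from sparkmagic.auth.basic import Basic
from sparkmagic.utils.utils import Namespace

# Placeholder values; real ones come from %spark magic arguments
# (e.g. -a username -p password), as the error message above notes.
args = Namespace(
    user="alice", password="s3cret", url="http://example.com/livy", auth=None
)
auth = Basic(parsed_attributes=args)
# The instance doubles as a requests-compatible auth callable through
# HTTPBasicAuth.__call__, e.g.: requests.get(args.url + "/sessions", auth=auth)
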
diff --git a/sparkmagic/sparkmagic/auth/customauth.py b/sparkmagic/sparkmagic/auth/customauth.py
index 3c8dfb6c7..b84261403 100644
--- a/sparkmagic/sparkmagic/auth/customauth.py
+++ b/sparkmagic/sparkmagic/auth/customauth.py
@@ -3,6 +3,7 @@
from hdijupyterutils.ipywidgetfactory import IpyWidgetFactory
from sparkmagic.utils.constants import WIDGET_WIDTH
+
class Authenticator(object):
"""Base Authenticator for all Sparkmagic authentication providers."""
@@ -19,7 +20,7 @@ def __init__(self, parsed_attributes=None):
if parsed_attributes is not None:
self.url = parsed_attributes.url
else:
- self.url = 'http://example.com/livy'
+ self.url = "http://example.com/livy"
self.widgets = self.get_widgets(WIDGET_WIDTH)
def get_widgets(self, widget_width):
@@ -34,9 +35,7 @@ def get_widgets(self, widget_width):
ipywidget_factory = IpyWidgetFactory()
self.address_widget = ipywidget_factory.get_text(
- description='Address:',
- value='http://example.com/livy',
- width=widget_width
+ description="Address:", value="http://example.com/livy", width=widget_width
)
widgets = [self.address_widget]
return widgets
diff --git a/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py b/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py
index 61e167272..b0557c417 100644
--- a/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py
@@ -5,9 +5,16 @@
class AbstractMenuWidget(Box):
- def __init__(self, spark_controller, ipywidget_factory=None, ipython_display=None,
- nested_widget_mode=False, testing=False, **kwargs):
- kwargs['orientation'] = 'vertical'
+ def __init__(
+ self,
+ spark_controller,
+ ipywidget_factory=None,
+ ipython_display=None,
+ nested_widget_mode=False,
+ testing=False,
+ **kwargs
+ ):
+ kwargs["orientation"] = "vertical"
if not testing:
super(AbstractMenuWidget, self).__init__((), **kwargs)
diff --git a/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py b/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py
index ad460b36c..c795c80ab 100644
--- a/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py
@@ -8,48 +8,64 @@
class AddEndpointWidget(AbstractMenuWidget):
-
- def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints, endpoints_dropdown_widget,
- refresh_method):
+ def __init__(
+ self,
+ spark_controller,
+ ipywidget_factory,
+ ipython_display,
+ endpoints,
+ endpoints_dropdown_widget,
+ refresh_method,
+ ):
# This is nested
- super(AddEndpointWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True)
+ super(AddEndpointWidget, self).__init__(
+ spark_controller, ipywidget_factory, ipython_display, True
+ )
self.endpoints = endpoints
self.endpoints_dropdown_widget = endpoints_dropdown_widget
self.refresh_method = refresh_method
- #map auth class path string to the instance of the class.
+ # map auth class path string to the instance of the class.
self.auth_instances = {}
for auth in conf.authenticators().values():
- module, class_name = (auth).rsplit('.', 1)
+ module, class_name = (auth).rsplit(".", 1)
events_handler_module = importlib.import_module(module)
auth_class = getattr(events_handler_module, class_name)
self.auth_instances[auth] = auth_class()
self.auth_type = self.ipywidget_factory.get_dropdown(
- options=conf.authenticators(),
- description=u"Auth type:"
+ options=conf.authenticators(), description="Auth type:"
)
- #combine all authentication instance's widgets into one list to pass to self.children.
+ # combine all authentication instance's widgets into one list to pass to self.children.
self.all_widgets = list()
for _class, instance in self.auth_instances.items():
for widget in instance.widgets:
- if _class == self.auth_type.value:
- widget.layout.display = 'flex'
+ if _class == self.auth_type.value:
+ widget.layout.display = "flex"
self.auth = instance
else:
- widget.layout.display = 'none'
+ widget.layout.display = "none"
self.all_widgets.append(widget)
# Submit widget
self.submit_widget = self.ipywidget_factory.get_submit_button(
- description='Add endpoint'
+ description="Add endpoint"
)
self.auth_type.on_trait_change(self._update_auth)
-        self.children = [self.ipywidget_factory.get_html(value="<br/>", width=WIDGET_WIDTH), self.auth_type] + self.all_widgets \
-            + [self.ipywidget_factory.get_html(value="<br/>", width=WIDGET_WIDTH), self.submit_widget]
+        self.children = (
+            [
+                self.ipywidget_factory.get_html(value="<br/>", width=WIDGET_WIDTH),
+                self.auth_type,
+            ]
+            + self.all_widgets
+            + [
+                self.ipywidget_factory.get_html(value="<br/>", width=WIDGET_WIDTH),
+                self.submit_widget,
+            ]
+        )
for child in self.children:
child.parent_widget = self
@@ -77,7 +93,7 @@ def _update_auth(self):
Create an instance of the chosen auth type maps to in the config file.
"""
for widget in self.auth.widgets:
- widget.layout.display = 'none'
+ widget.layout.display = "none"
self.auth = self.auth_instances.get(self.auth_type.value)
for widget in self.auth.widgets:
- widget.layout.display = 'flex'
+ widget.layout.display = "flex"
diff --git a/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py b/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py
index af534a7e5..d1d7aaa7d 100644
--- a/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py
@@ -8,33 +8,46 @@
class CreateSessionWidget(AbstractMenuWidget):
- def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints_dropdown_widget, refresh_method):
+ def __init__(
+ self,
+ spark_controller,
+ ipywidget_factory,
+ ipython_display,
+ endpoints_dropdown_widget,
+ refresh_method,
+ ):
# This is nested
- super(CreateSessionWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True)
+ super(CreateSessionWidget, self).__init__(
+ spark_controller, ipywidget_factory, ipython_display, True
+ )
self.refresh_method = refresh_method
self.endpoints_dropdown_widget = endpoints_dropdown_widget
self.session_widget = self.ipywidget_factory.get_text(
- description='Name:',
- value='session-name'
+ description="Name:", value="session-name"
)
self.lang_widget = self.ipywidget_factory.get_toggle_buttons(
- description='Language:',
+ description="Language:",
options=[LANG_SCALA, LANG_PYTHON],
)
self.properties = self.ipywidget_factory.get_text(
- description='Properties:',
- value=json.dumps(conf.session_configs())
+ description="Properties:", value=json.dumps(conf.session_configs())
)
self.submit_widget = self.ipywidget_factory.get_submit_button(
- description='Create Session'
+ description="Create Session"
)
-        self.children = [self.ipywidget_factory.get_html(value="<br/>", width="600px"), self.endpoints_dropdown_widget,
-                         self.session_widget, self.lang_widget, self.properties,
-                         self.ipywidget_factory.get_html(value="<br/>", width="600px"), self.submit_widget]
+        self.children = [
+            self.ipywidget_factory.get_html(value="<br/>", width="600px"),
+            self.endpoints_dropdown_widget,
+            self.session_widget,
+            self.lang_widget,
+            self.properties,
+            self.ipywidget_factory.get_html(value="<br/>", width="600px"),
+            self.submit_widget,
+        ]
for child in self.children:
child.parent_widget = self
@@ -43,9 +56,13 @@ def run(self):
try:
properties_json = self.properties.value
if properties_json.strip() != "":
- conf.override(conf.session_configs.__name__, json.loads(self.properties.value))
+ conf.override(
+ conf.session_configs.__name__, json.loads(self.properties.value)
+ )
except ValueError as e:
- self.ipython_display.send_error("Session properties must be a valid JSON string. Error:\n{}".format(e))
+ self.ipython_display.send_error(
+ "Session properties must be a valid JSON string. Error:\n{}".format(e)
+ )
return
endpoint = self.endpoints_dropdown_widget.value
@@ -57,13 +74,17 @@ def run(self):
try:
self.spark_controller.add_session(alias, endpoint, skip, properties)
except ValueError as e:
- self.ipython_display.send_error("""Could not add session with
+ self.ipython_display.send_error(
+ """Could not add session with
name:
{}
properties:
{}
-due to error: '{}'""".format(alias, properties, e))
+due to error: '{}'""".format(
+ alias, properties, e
+ )
+ )
return
self.refresh_method()
diff --git a/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py b/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py
index 7e3ffd298..9e990ce57 100644
--- a/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py
@@ -12,11 +12,17 @@
class MagicsControllerWidget(AbstractMenuWidget):
- def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints=None):
- super(MagicsControllerWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display)
+ def __init__(
+ self, spark_controller, ipywidget_factory, ipython_display, endpoints=None
+ ):
+ super(MagicsControllerWidget, self).__init__(
+ spark_controller, ipywidget_factory, ipython_display
+ )
if endpoints is None:
- endpoints = {endpoint.url: endpoint for endpoint in self._get_default_endpoints()}
+ endpoints = {
+ endpoint.url: endpoint for endpoint in self._get_default_endpoints()
+ }
self.endpoints = endpoints
self._refresh()
@@ -29,37 +35,73 @@ def _get_default_endpoints():
default_endpoints = set()
for kernel_type in LANGS_SUPPORTED:
- endpoint_config = getattr(conf, 'kernel_%s_credentials' % kernel_type)()
- if all([p in endpoint_config for p in ["url", "password", "username"]]) and endpoint_config["url"] != "":
+ endpoint_config = getattr(conf, "kernel_%s_credentials" % kernel_type)()
+ if (
+ all([p in endpoint_config for p in ["url", "password", "username"]])
+ and endpoint_config["url"] != ""
+ ):
user = endpoint_config["username"]
passwd = endpoint_config["password"]
- args = Namespace(user=user, password=passwd, auth=endpoint_config.get("auth", None), url=endpoint_config.get("url", None))
+ args = Namespace(
+ user=user,
+ password=passwd,
+ auth=endpoint_config.get("auth", None),
+ url=endpoint_config.get("url", None),
+ )
auth_instance = initialize_auth(args)
- default_endpoints.add(Endpoint(
- auth=auth_instance,
- url=endpoint_config["url"],
- implicitly_added=True))
+ default_endpoints.add(
+ Endpoint(
+ auth=auth_instance,
+ url=endpoint_config["url"],
+ implicitly_added=True,
+ )
+ )
return default_endpoints
def _refresh(self):
self.endpoints_dropdown_widget = self.ipywidget_factory.get_dropdown(
- description="Endpoint:",
- options=self.endpoints
+ description="Endpoint:", options=self.endpoints
)
- self.manage_session = ManageSessionWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display,
- self._refresh)
- self.create_session = CreateSessionWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display,
- self.endpoints_dropdown_widget, self._refresh)
- self.add_endpoint = AddEndpointWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display,
- self.endpoints, self.endpoints_dropdown_widget, self._refresh)
- self.manage_endpoint = ManageEndpointWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display,
- self.endpoints, self._refresh)
+ self.manage_session = ManageSessionWidget(
+ self.spark_controller,
+ self.ipywidget_factory,
+ self.ipython_display,
+ self._refresh,
+ )
+ self.create_session = CreateSessionWidget(
+ self.spark_controller,
+ self.ipywidget_factory,
+ self.ipython_display,
+ self.endpoints_dropdown_widget,
+ self._refresh,
+ )
+ self.add_endpoint = AddEndpointWidget(
+ self.spark_controller,
+ self.ipywidget_factory,
+ self.ipython_display,
+ self.endpoints,
+ self.endpoints_dropdown_widget,
+ self._refresh,
+ )
+ self.manage_endpoint = ManageEndpointWidget(
+ self.spark_controller,
+ self.ipywidget_factory,
+ self.ipython_display,
+ self.endpoints,
+ self._refresh,
+ )
- self.tabs = self.ipywidget_factory.get_tab(children=[self.manage_session, self.create_session,
- self.add_endpoint, self.manage_endpoint])
+ self.tabs = self.ipywidget_factory.get_tab(
+ children=[
+ self.manage_session,
+ self.create_session,
+ self.add_endpoint,
+ self.manage_endpoint,
+ ]
+ )
self.tabs.set_title(0, "Manage Sessions")
self.tabs.set_title(1, "Create Session")
self.tabs.set_title(2, "Add Endpoint")
diff --git a/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py b/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py
index c210d6264..1020bfdf0 100644
--- a/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py
@@ -6,9 +6,18 @@
class ManageEndpointWidget(AbstractMenuWidget):
- def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints, refresh_method):
+ def __init__(
+ self,
+ spark_controller,
+ ipywidget_factory,
+ ipython_display,
+ endpoints,
+ refresh_method,
+ ):
# This is nested
- super(ManageEndpointWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True)
+ super(ManageEndpointWidget, self).__init__(
+ spark_controller, ipywidget_factory, ipython_display, True
+ )
self.logger = SparkLog("ManageEndpointWidget")
self.endpoints = endpoints
@@ -24,7 +33,9 @@ def run(self):
def get_existing_endpoint_widgets(self):
endpoint_widgets = []
-        endpoint_widgets.append(self.ipywidget_factory.get_html(value="<br/>", width="600px"))
+        endpoint_widgets.append(
+            self.ipywidget_factory.get_html(value="<br/>", width="600px")
+        )
if len(self.endpoints) > 0:
# Header
@@ -40,11 +51,20 @@ def get_existing_endpoint_widgets(self):
if not endpoint.implicitly_added:
raise
else:
- self.logger.info("Failed to connect to implicitly-defined endpoint at: %s" % url)
-
-            endpoint_widgets.append(self.ipywidget_factory.get_html(value="<br/>", width="600px"))
+ self.logger.info(
+ "Failed to connect to implicitly-defined endpoint at: %s"
+ % url
+ )
+
+ endpoint_widgets.append(
+ self.ipywidget_factory.get_html(value="
", width="600px")
+ )
else:
- endpoint_widgets.append(self.ipywidget_factory.get_html(value="No endpoints yet.", width="600px"))
+ endpoint_widgets.append(
+ self.ipywidget_factory.get_html(
+ value="No endpoints yet.", width="600px"
+ )
+ )
return endpoint_widgets
@@ -63,7 +83,9 @@ def get_endpoint_widget(self, url, endpoint):
hbox_outter_children.append(vbox_left)
hbox_outter_children.append(cleanup_w)
except ValueError as e:
- hbox_outter_children.append(self.ipywidget_factory.get_html(value=str(e), width=width))
+ hbox_outter_children.append(
+ self.ipywidget_factory.get_html(value=str(e), width=width)
+ )
hbox_outter_children.append(self.get_delete_button_endpoint(url, endpoint))
hbox_outter.children = hbox_outter_children
@@ -76,7 +98,9 @@ def get_endpoint_left(self, endpoint, url):
# 400 px
info = self.get_info_endpoint_widget(endpoint, url)
delete_session_number = self.get_delete_session_endpoint_widget(url, endpoint)
- vbox_left = self.ipywidget_factory.get_vbox(children=[info, delete_session_number], width="400px")
+ vbox_left = self.ipywidget_factory.get_vbox(
+ children=[info, delete_session_number], width="400px"
+ )
return vbox_left
def get_cleanup_button_endpoint(self, url, endpoint):
@@ -84,7 +108,9 @@ def cleanup_on_click(button):
try:
self.spark_controller.cleanup_endpoint(endpoint)
except ValueError as e:
- self.ipython_display.send_error("Could not clean up endpoint due to error: {}".format(e))
+ self.ipython_display.send_error(
+ "Could not clean up endpoint due to error: {}".format(e)
+ )
return
self.ipython_display.writeln("Cleaned up endpoint {}".format(url))
self.refresh_method()
@@ -105,7 +131,9 @@ def delete_on_click(button):
return delete_w
def get_delete_session_endpoint_widget(self, url, endpoint):
- session_text = self.ipywidget_factory.get_text(description="Session to delete:", value="0", width="50px")
+ session_text = self.ipywidget_factory.get_text(
+ description="Session to delete:", value="0", width="50px"
+ )
def delete_endpoint(button):
try:
@@ -120,7 +148,9 @@ def delete_endpoint(button):
button = self.ipywidget_factory.get_button(description="Delete")
button.on_click(delete_endpoint)
- return self.ipywidget_factory.get_hbox(children=[session_text, button], width="152px")
+ return self.ipywidget_factory.get_hbox(
+ children=[session_text, button], width="152px"
+ )
def get_info_endpoint_widget(self, endpoint, url):
# 400 px
@@ -129,7 +159,9 @@ def get_info_endpoint_widget(self, endpoint, url):
info_sessions = self.spark_controller.get_all_sessions_endpoint_info(endpoint)
if len(info_sessions) > 0:
- text = "{}:
{}".format(url, "* {}".format("
* ".join(info_sessions)))
+ text = "{}:
{}".format(
+ url, "* {}".format("
* ".join(info_sessions))
+ )
else:
text = "No sessions on this endpoint."
diff --git a/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py b/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py
index 268f049ab..ebd38951d 100644
--- a/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py
+++ b/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py
@@ -4,9 +4,13 @@
class ManageSessionWidget(AbstractMenuWidget):
- def __init__(self, spark_controller, ipywidget_factory, ipython_display, refresh_method):
+ def __init__(
+ self, spark_controller, ipywidget_factory, ipython_display, refresh_method
+ ):
# This is nested
- super(ManageSessionWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True)
+ super(ManageSessionWidget, self).__init__(
+ spark_controller, ipywidget_factory, ipython_display, True
+ )
self.refresh_method = refresh_method
@@ -20,34 +24,55 @@ def run(self):
def get_existing_session_widgets(self):
session_widgets = []
-        session_widgets.append(self.ipywidget_factory.get_html(value="<br/>", width="600px"))
+        session_widgets.append(
+            self.ipywidget_factory.get_html(value="<br/>", width="600px")
+        )
client_dict = self.spark_controller.get_managed_clients()
if len(client_dict) > 0:
# Header
header = self.get_session_widget("Name", "Id", "Kind", "State", False)
session_widgets.append(header)
-            session_widgets.append(self.ipywidget_factory.get_html(value="<br/>", width="600px"))
+            session_widgets.append(
+                self.ipywidget_factory.get_html(value="<br/>", width="600px")
+            )
# Sessions
for name, session in client_dict.items():
- session_widgets.append(self.get_session_widget(name, session.id, session.kind, session.status))
-
-            session_widgets.append(self.ipywidget_factory.get_html(value="<br/>", width="600px"))
+ session_widgets.append(
+ self.get_session_widget(
+ name, session.id, session.kind, session.status
+ )
+ )
+
+ session_widgets.append(
+ self.ipywidget_factory.get_html(value="
", width="600px")
+ )
else:
- session_widgets.append(self.ipywidget_factory.get_html(value="No sessions yet.", width="600px"))
+ session_widgets.append(
+ self.ipywidget_factory.get_html(value="No sessions yet.", width="600px")
+ )
return session_widgets
def get_session_widget(self, name, session_id, kind, state, button=True):
hbox = self.ipywidget_factory.get_hbox()
- name_w = self.ipywidget_factory.get_html(value=name, width="200px", padding="4px")
- id_w = self.ipywidget_factory.get_html(value=str(session_id), width="100px", padding="4px")
- kind_w = self.ipywidget_factory.get_html(value=kind, width="100px", padding="4px")
- state_w = self.ipywidget_factory.get_html(value=state, width="100px", padding="4px")
+ name_w = self.ipywidget_factory.get_html(
+ value=name, width="200px", padding="4px"
+ )
+ id_w = self.ipywidget_factory.get_html(
+ value=str(session_id), width="100px", padding="4px"
+ )
+ kind_w = self.ipywidget_factory.get_html(
+ value=kind, width="100px", padding="4px"
+ )
+ state_w = self.ipywidget_factory.get_html(
+ value=state, width="100px", padding="4px"
+ )
if button:
+
def delete_on_click(button):
self.spark_controller.delete_session_by_name(name)
self.refresh_method()
@@ -55,7 +80,9 @@ def delete_on_click(button):
delete_w = self.ipywidget_factory.get_button(description="Delete")
delete_w.on_click(delete_on_click)
else:
- delete_w = self.ipywidget_factory.get_html(value="", width="100px", padding="4px")
+ delete_w = self.ipywidget_factory.get_html(
+ value="", width="100px", padding="4px"
+ )
hbox.children = [name_w, id_w, kind_w, state_w, delete_w]
diff --git a/sparkmagic/sparkmagic/kernels/kernelmagics.py b/sparkmagic/sparkmagic/kernels/kernelmagics.py
index b22771994..1269947ad 100644
--- a/sparkmagic/sparkmagic/kernels/kernelmagics.py
+++ b/sparkmagic/sparkmagic/kernels/kernelmagics.py
@@ -15,31 +15,52 @@
import sparkmagic.utils.configuration as conf
from sparkmagic.utils.configuration import get_livy_kind
from sparkmagic.utils import constants
-from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value, initialize_auth, Namespace
+from sparkmagic.utils.utils import (
+ parse_argstring_or_throw,
+ get_coerce_value,
+ initialize_auth,
+ Namespace,
+)
from sparkmagic.utils.sparkevents import SparkEvents
from sparkmagic.utils.constants import LANGS_SUPPORTED
-from sparkmagic.utils.dataframe_parser import cell_contains_dataframe, CellOutputHtmlParser
+from sparkmagic.utils.dataframe_parser import (
+ cell_contains_dataframe,
+ CellOutputHtmlParser,
+)
from sparkmagic.livyclientlib.command import Command
from sparkmagic.livyclientlib.endpoint import Endpoint
from sparkmagic.magics.sparkmagicsbase import SparkMagicBase, SparkOutputHandler
-from sparkmagic.livyclientlib.exceptions import handle_expected_exceptions, wrap_unexpected_exceptions, \
- BadUserDataException
+from sparkmagic.livyclientlib.exceptions import (
+ handle_expected_exceptions,
+ wrap_unexpected_exceptions,
+ BadUserDataException,
+)
def _event(f):
def wrapped(self, *args, **kwargs):
guid = self._generate_uuid()
- self._spark_events.emit_magic_execution_start_event(f.__name__, get_livy_kind(self.language), guid)
+ self._spark_events.emit_magic_execution_start_event(
+ f.__name__, get_livy_kind(self.language), guid
+ )
try:
result = f(self, *args, **kwargs)
except Exception as e:
- self._spark_events.emit_magic_execution_end_event(f.__name__, get_livy_kind(self.language), guid,
- False, e.__class__.__name__, str(e))
+ self._spark_events.emit_magic_execution_end_event(
+ f.__name__,
+ get_livy_kind(self.language),
+ guid,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
raise
else:
- self._spark_events.emit_magic_execution_end_event(f.__name__, get_livy_kind(self.language), guid,
- True, u"", u"")
+ self._spark_events.emit_magic_execution_end_event(
+ f.__name__, get_livy_kind(self.language), guid, True, "", ""
+ )
return result
+
wrapped.__name__ = f.__name__
wrapped.__doc__ = f.__doc__
return wrapped
@@ -51,15 +72,15 @@ def __init__(self, shell, data=None, spark_events=None):
# You must call the parent constructor
super(KernelMagics, self).__init__(shell, data)
- self.session_name = u"session_name"
+ self.session_name = "session_name"
self.session_started = False
# In order to set these following 3 properties, call %%_do_not_call_change_language -l language
- self.language = u""
+ self.language = ""
self.endpoint = None
self.fatal_error = False
self.allow_retry_fatal = False
- self.fatal_error_message = u""
+ self.fatal_error_message = ""
if spark_events is None:
spark_events = SparkEvents()
self._spark_events = spark_events
@@ -72,7 +93,7 @@ def __init__(self, shell, data=None, spark_events=None):
def help(self, line, cell="", local_ns=None):
parse_argstring_or_throw(self.help, line)
self._assure_cell_body_is_empty(KernelMagics.help.__name__, cell)
- help_html = u"""
+ help_html = """
 <table>
   <tr>
     <th>Magic</th>
@@ -165,23 +186,49 @@ def help(self, line, cell="", local_ns=None):
self.ipython_display.html(help_html)
@cell_magic
- def local(self, line, cell=u"", local_ns=None):
+ def local(self, line, cell="", local_ns=None):
# This should not be reachable thanks to UserCodeParser. Registering it here so that it auto-completes with tab.
- raise NotImplementedError(u"UserCodeParser should have prevented code execution from reaching here.")
+ raise NotImplementedError(
+ "UserCodeParser should have prevented code execution from reaching here."
+ )
@magic_arguments()
- @argument("-i", "--input", type=str, default=None, help="If present, indicated variable will be stored in variable"
- " in Spark's context.")
- @argument("-t", "--vartype", type=str, default='str', help="Optionally specify the type of input variable. "
- "Available: 'str' - string(default) or 'df' - Pandas DataFrame")
- @argument("-n", "--varname", type=str, default=None, help="Optionally specify the custom name for the input variable.")
- @argument("-m", "--maxrows", type=int, default=2500, help="Maximum number of rows that will be pulled back "
- "from the local dataframe")
+ @argument(
+ "-i",
+ "--input",
+ type=str,
+ default=None,
+ help="If present, indicated variable will be stored in variable"
+ " in Spark's context.",
+ )
+ @argument(
+ "-t",
+ "--vartype",
+ type=str,
+ default="str",
+ help="Optionally specify the type of input variable. "
+ "Available: 'str' - string(default) or 'df' - Pandas DataFrame",
+ )
+ @argument(
+ "-n",
+ "--varname",
+ type=str,
+ default=None,
+ help="Optionally specify the custom name for the input variable.",
+ )
+ @argument(
+ "-m",
+ "--maxrows",
+ type=int,
+ default=2500,
+ help="Maximum number of rows that will be pulled back "
+ "from the local dataframe",
+ )
@cell_magic
@needs_local_scope
@wrap_unexpected_exceptions
@handle_expected_exceptions
- def send_to_spark(self, line, cell=u"", local_ns=None):
+ def send_to_spark(self, line, cell="", local_ns=None):
self._assure_cell_body_is_empty(KernelMagics.send_to_spark.__name__, cell)
args = parse_argstring_or_throw(self.send_to_spark, line)
@@ -189,7 +236,9 @@ def send_to_spark(self, line, cell=u"", local_ns=None):
raise BadUserDataException("-i param not provided.")
if self._do_not_call_start_session(""):
- self.do_send_to_spark(cell, args.input, args.vartype, args.varname, args.maxrows, None)
+ self.do_send_to_spark(
+ cell, args.input, args.vartype, args.varname, args.maxrows, None
+ )
else:
return
@@ -198,15 +247,21 @@ def send_to_spark(self, line, cell=u"", local_ns=None):
@wrap_unexpected_exceptions
@handle_expected_exceptions
@_event
- def info(self, line, cell=u"", local_ns=None):
+ def info(self, line, cell="", local_ns=None):
parse_argstring_or_throw(self.info, line)
self._assure_cell_body_is_empty(KernelMagics.info.__name__, cell)
if self.session_started:
- current_session_id = self.spark_controller.get_session_id_for_client(self.session_name)
+ current_session_id = self.spark_controller.get_session_id_for_client(
+ self.session_name
+ )
else:
current_session_id = None
- self.ipython_display.html(u"Current session configs: {}
".format(conf.get_session_properties(self.language)))
+ self.ipython_display.html(
+ "Current session configs: {}
".format(
+ conf.get_session_properties(self.language)
+ )
+ )
info_sessions = self.spark_controller.get_all_sessions_endpoint(self.endpoint)
self._print_endpoint_info(info_sessions, current_session_id)
@@ -223,11 +278,19 @@ def logs(self, line, cell="", local_ns=None):
out = self.spark_controller.get_logs()
self.ipython_display.write(out)
else:
- self.ipython_display.write(u"No logs yet.")
+ self.ipython_display.write("No logs yet.")
@magic_arguments()
@cell_magic
- @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.")
+ @argument(
+ "-f",
+ "--force",
+ type=bool,
+ default=False,
+ nargs="?",
+ const=True,
+ help="If present, user understands.",
+ )
@wrap_unexpected_exceptions
@handle_expected_exceptions
@_event
@@ -235,44 +298,86 @@ def configure(self, line, cell="", local_ns=None):
try:
dictionary = json.loads(cell)
except ValueError:
- self.ipython_display.send_error(u"Could not parse JSON object from input '{}'".format(cell))
+ self.ipython_display.send_error(
+ "Could not parse JSON object from input '{}'".format(cell)
+ )
return
args = parse_argstring_or_throw(self.configure, line)
if self.session_started:
if not args.force:
- self.ipython_display.send_error(u"A session has already been started. If you intend to recreate the "
- u"session with new configurations, please include the -f argument.")
+ self.ipython_display.send_error(
+ "A session has already been started. If you intend to recreate the "
+ "session with new configurations, please include the -f argument."
+ )
return
else:
- self._do_not_call_delete_session(u"")
+ self._do_not_call_delete_session("")
self._override_session_settings(dictionary)
- self._do_not_call_start_session(u"")
+ self._do_not_call_start_session("")
else:
self._override_session_settings(dictionary)
- self.info(u"")
+ self.info("")
@magic_arguments()
@cell_magic
@needs_local_scope
- @argument("-o", "--output", type=str, default=None, help="If present, indicated variable will be stored in variable"
- "of this name in user's local context.")
- @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for dataframe: either take or sample")
- @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back "
- "from the dataframe on the server for storing")
- @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from dataframe")
- @argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
- "of the dataframe or not (pass False)")
+ @argument(
+ "-o",
+ "--output",
+ type=str,
+ default=None,
+ help="If present, indicated variable will be stored in variable"
+ "of this name in user's local context.",
+ )
+ @argument(
+ "-m",
+ "--samplemethod",
+ type=str,
+ default=None,
+ help="Sample method for dataframe: either take or sample",
+ )
+ @argument(
+ "-n",
+ "--maxrows",
+ type=int,
+ default=None,
+ help="Maximum number of rows that will be pulled back "
+ "from the dataframe on the server for storing",
+ )
+ @argument(
+ "-r",
+ "--samplefraction",
+ type=float,
+ default=None,
+ help="Sample fraction for sampling from dataframe",
+ )
+ @argument(
+ "-c",
+ "--coerce",
+ type=str,
+ default=None,
+ help="Whether to automatically coerce the types (default, pass True if being explicit) "
+ "of the dataframe or not (pass False)",
+ )
@wrap_unexpected_exceptions
@handle_expected_exceptions
def spark(self, line, cell="", local_ns=None):
- if not self._do_not_call_start_session(u""):
+ if not self._do_not_call_start_session(""):
return
args = parse_argstring_or_throw(self.spark, line)
coerce = get_coerce_value(args.coerce)
- self.execute_spark(cell, args.output, args.samplemethod, args.maxrows, args.samplefraction, None, coerce)
+ self.execute_spark(
+ cell,
+ args.output,
+ args.samplemethod,
+ args.maxrows,
+ args.samplefraction,
+ None,
+ coerce,
+ )
@cell_magic
@needs_local_scope
@@ -280,34 +385,72 @@ def spark(self, line, cell="", local_ns=None):
@handle_expected_exceptions
def pretty(self, line, cell="", local_ns=None):
"""Evaluates a cell and converts dataframes in cell output to HTML tables."""
- if not self._do_not_call_start_session(u""):
- return
+ if not self._do_not_call_start_session(""):
+ return
def pretty_output_handler(out):
if cell_contains_dataframe(out):
- self.ipython_display.html(CellOutputHtmlParser.to_html(out))
- else:
+ self.ipython_display.html(CellOutputHtmlParser.to_html(out))
+ else:
self.ipython_display.write(out)
- so = SparkOutputHandler(html=self.ipython_display.html,
- text=pretty_output_handler,
- default=self.ipython_display.display)
-
- self.execute_spark(cell, None, None, None, None, None, None, output_handler=so)
+ so = SparkOutputHandler(
+ html=self.ipython_display.html,
+ text=pretty_output_handler,
+ default=self.ipython_display.display,
+ )
+ self.execute_spark(cell, None, None, None, None, None, None, output_handler=so)
@magic_arguments()
@cell_magic
@needs_local_scope
- @argument("-o", "--output", type=str, default=None, help="If present, query will be stored in variable of this "
- "name.")
- @argument("-q", "--quiet", type=bool, default=False, const=True, nargs="?", help="Return None instead of the dataframe.")
- @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for SQL queries: either take or sample")
- @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back "
- "from the server for SQL queries")
- @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from SQL queries")
- @argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
- "of the dataframe or not (pass False)")
+ @argument(
+ "-o",
+ "--output",
+ type=str,
+ default=None,
+ help="If present, query will be stored in variable of this " "name.",
+ )
+ @argument(
+ "-q",
+ "--quiet",
+ type=bool,
+ default=False,
+ const=True,
+ nargs="?",
+ help="Return None instead of the dataframe.",
+ )
+ @argument(
+ "-m",
+ "--samplemethod",
+ type=str,
+ default=None,
+ help="Sample method for SQL queries: either take or sample",
+ )
+ @argument(
+ "-n",
+ "--maxrows",
+ type=int,
+ default=None,
+ help="Maximum number of rows that will be pulled back "
+ "from the server for SQL queries",
+ )
+ @argument(
+ "-r",
+ "--samplefraction",
+ type=float,
+ default=None,
+ help="Sample fraction for sampling from SQL queries",
+ )
+ @argument(
+ "-c",
+ "--coerce",
+ type=str,
+ default=None,
+ help="Whether to automatically coerce the types (default, pass True if being explicit) "
+ "of the dataframe or not (pass False)",
+ )
@wrap_unexpected_exceptions
@handle_expected_exceptions
def sql(self, line, cell="", local_ns=None):
@@ -318,12 +461,28 @@ def sql(self, line, cell="", local_ns=None):
coerce = get_coerce_value(args.coerce)
- return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, args.samplefraction,
- None, args.output, args.quiet, coerce)
+ return self.execute_sqlquery(
+ cell,
+ args.samplemethod,
+ args.maxrows,
+ args.samplefraction,
+ None,
+ args.output,
+ args.quiet,
+ coerce,
+ )
@magic_arguments()
@cell_magic
- @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.")
+ @argument(
+ "-f",
+ "--force",
+ type=bool,
+ default=False,
+ nargs="?",
+ const=True,
+ help="If present, user understands.",
+ )
@wrap_unexpected_exceptions
@handle_expected_exceptions
@_event
@@ -331,18 +490,28 @@ def cleanup(self, line, cell="", local_ns=None):
self._assure_cell_body_is_empty(KernelMagics.cleanup.__name__, cell)
args = parse_argstring_or_throw(self.cleanup, line)
if args.force:
- self._do_not_call_delete_session(u"")
+ self._do_not_call_delete_session("")
self.spark_controller.cleanup_endpoint(self.endpoint)
else:
- self.ipython_display.send_error(u"When you clean up the endpoint, all sessions will be lost, including the "
- u"one used for this notebook. Include the -f parameter if that's your "
- u"intention.")
+ self.ipython_display.send_error(
+ "When you clean up the endpoint, all sessions will be lost, including the "
+ "one used for this notebook. Include the -f parameter if that's your "
+ "intention."
+ )
return
@magic_arguments()
@cell_magic
- @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.")
+ @argument(
+ "-f",
+ "--force",
+ type=bool,
+ default=False,
+ nargs="?",
+ const=True,
+ help="If present, user understands.",
+ )
@argument("-s", "--session", type=int, help="Session id number to delete.")
@wrap_unexpected_exceptions
@handle_expected_exceptions
@@ -353,21 +522,27 @@ def delete(self, line, cell="", local_ns=None):
session = args.session
if args.session is None:
- self.ipython_display.send_error(u'You must provide a session ID (-s argument).')
+ self.ipython_display.send_error(
+ "You must provide a session ID (-s argument)."
+ )
return
if args.force:
id = self.spark_controller.get_session_id_for_client(self.session_name)
if session == id:
- self.ipython_display.send_error(u"Cannot delete this kernel's session ({}). Specify a different session,"
- u" shutdown the kernel to delete this session, or run %cleanup to "
- u"delete all sessions for this endpoint.".format(id))
+ self.ipython_display.send_error(
+ "Cannot delete this kernel's session ({}). Specify a different session,"
+ " shutdown the kernel to delete this session, or run %cleanup to "
+ "delete all sessions for this endpoint.".format(id)
+ )
return
self.spark_controller.delete_session_by_id(self.endpoint, session)
else:
- self.ipython_display.send_error(u"Include the -f parameter if you understand that all statements executed "
- u"in this session will be lost.")
+ self.ipython_display.send_error(
+ "Include the -f parameter if you understand that all statements executed "
+ "in this session will be lost."
+ )
@cell_magic
def _do_not_call_start_session(self, line, cell="", local_ns=None):
@@ -385,14 +560,16 @@ def _do_not_call_start_session(self, line, cell="", local_ns=None):
properties = conf.get_session_properties(self.language)
try:
- self.spark_controller.add_session(self.session_name, self.endpoint, skip, properties)
+ self.spark_controller.add_session(
+ self.session_name, self.endpoint, skip, properties
+ )
self.session_started = True
self.fatal_error = False
- self.fatal_error_message = u""
+ self.fatal_error_message = ""
except Exception as e:
self.fatal_error = True
self.fatal_error_message = conf.fatal_error_suggestion().format(e)
- self.logger.error(u"Error creating session: {}".format(e))
+ self.logger.error("Error creating session: {}".format(e))
self.ipython_display.send_error(self.fatal_error_message)
if conf.all_errors_are_fatal():
@@ -427,11 +604,15 @@ def _do_not_call_change_language(self, line, cell="", local_ns=None):
language = args.language.lower()
if language not in LANGS_SUPPORTED:
- self.ipython_display.send_error(u"'{}' language not supported in kernel magics.".format(language))
+ self.ipython_display.send_error(
+ "'{}' language not supported in kernel magics.".format(language)
+ )
return
if self.session_started:
- self.ipython_display.send_error(u"Cannot change the language if a session has been started.")
+ self.ipython_display.send_error(
+ "Cannot change the language if a session has been started."
+ )
return
self.language = language
@@ -439,22 +620,24 @@ def _do_not_call_change_language(self, line, cell="", local_ns=None):
@magic_arguments()
@line_magic
- @argument("-u", "--username", dest='user', type=str, help="Username to use.")
+ @argument("-u", "--username", dest="user", type=str, help="Username to use.")
@argument("-p", "--password", type=str, help="Password to use.")
- @argument("-s", "--server", dest='url', type=str, help="Url of server to use.")
+ @argument("-s", "--server", dest="url", type=str, help="Url of server to use.")
@argument("-t", "--auth", type=str, help="Auth type for authentication")
@_event
def _do_not_call_change_endpoint(self, line, cell="", local_ns=None):
args = parse_argstring_or_throw(self._do_not_call_change_endpoint, line)
if self.session_started:
- error = u"Cannot change the endpoint if a session has been started."
+ error = "Cannot change the endpoint if a session has been started."
raise BadUserDataException(error)
auth = initialize_auth(args=args)
self.endpoint = Endpoint(args.url, auth)
@line_magic
def matplot(self, line, cell="", local_ns=None):
- session = self.spark_controller.get_session_by_name_or_default(self.session_name)
+ session = self.spark_controller.get_session_by_name_or_default(
+ self.session_name
+ )
command = Command("%matplot " + line)
(success, out, mimetype) = command.execute(session)
if success:
@@ -463,8 +646,13 @@ def matplot(self, line, cell="", local_ns=None):
session.ipython_display.send_error(out)
def refresh_configuration(self):
- credentials = getattr(conf, 'base64_kernel_' + self.language + '_credentials')()
- (username, password, auth, url) = (credentials['username'], credentials['password'], credentials['auth'], credentials['url'])
+ credentials = getattr(conf, "base64_kernel_" + self.language + "_credentials")()
+ (username, password, auth, url) = (
+ credentials["username"],
+ credentials["password"],
+ credentials["auth"],
+ credentials["url"],
+ )
args = Namespace(auth=auth, user=username, password=password, url=url)
auth_instance = initialize_auth(args)
self.endpoint = Endpoint(url, auth_instance)
@@ -492,8 +680,12 @@ def _generate_uuid():
@staticmethod
def _assure_cell_body_is_empty(magic_name, cell):
if cell.strip():
- raise BadUserDataException(u"Cell body for %%{} magic must be empty; got '{}' instead"
- .format(magic_name, cell.strip()))
+ raise BadUserDataException(
+ "Cell body for %%{} magic must be empty; got '{}' instead".format(
+ magic_name, cell.strip()
+ )
+ )
+
def load_ipython_extension(ip):
ip.register_magics(KernelMagics)
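
For reference, load_ipython_extension above is the hook that %load_ext invokes; the wrapper kernels run "%load_ext sparkmagic.kernels" on startup (see the sparkkernelbase.py hunk below). A minimal sketch, assuming an already-running IPython shell:

from IPython import get_ipython

ip = get_ipython()  # returns None outside IPython; a live shell is assumed here
ip.run_line_magic("load_ext", "sparkmagic.kernels")
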
diff --git a/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py b/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py
index 99c4176c3..f102a9cad 100644
--- a/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py
+++ b/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py
@@ -1 +1 @@
-__version__ = '0.0.1'
\ No newline at end of file
+__version__ = "0.0.1"
diff --git a/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py b/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py
index 7cbb6364e..a40c7e488 100644
--- a/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py
+++ b/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py
@@ -6,29 +6,32 @@
class PySparkKernel(SparkKernelBase):
def __init__(self, **kwargs):
- implementation = 'PySpark'
- implementation_version = '1.0'
+ implementation = "PySpark"
+ implementation_version = "1.0"
language = LANG_PYTHON
- language_version = '0.1'
+ language_version = "0.1"
language_info = {
- 'name': 'pyspark',
- 'mimetype': 'text/x-python',
- 'codemirror_mode': {
- 'name': 'python',
- 'version': 3
- },
- 'file_extension': '.py',
- 'pygments_lexer': 'python3'
+ "name": "pyspark",
+ "mimetype": "text/x-python",
+ "codemirror_mode": {"name": "python", "version": 3},
+ "file_extension": ".py",
+ "pygments_lexer": "python3",
}
session_language = LANG_PYTHON
- super(PySparkKernel,
- self).__init__(implementation, implementation_version, language,
- language_version, language_info, session_language,
- **kwargs)
+ super(PySparkKernel, self).__init__(
+ implementation,
+ implementation_version,
+ language,
+ language_version,
+ language_info,
+ session_language,
+ **kwargs
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
from ipykernel.kernelapp import IPKernelApp
+
IPKernelApp.launch_instance(kernel_class=PySparkKernel)
diff --git a/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py b/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py
index 99c4176c3..f102a9cad 100644
--- a/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py
+++ b/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py
@@ -1 +1 @@
-__version__ = '0.0.1'
\ No newline at end of file
+__version__ = "0.0.1"
diff --git a/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py b/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py
index 365eafe96..22da09b4d 100644
--- a/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py
+++ b/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py
@@ -6,26 +6,32 @@
class SparkKernel(SparkKernelBase):
def __init__(self, **kwargs):
- implementation = 'Spark'
- implementation_version = '1.0'
+ implementation = "Spark"
+ implementation_version = "1.0"
language = LANG_SCALA
- language_version = '0.1'
+ language_version = "0.1"
language_info = {
- 'name': 'scala',
- 'mimetype': 'text/x-scala',
- 'codemirror_mode': 'text/x-scala',
- 'file_extension': '.sc',
- 'pygments_lexer': 'scala'
+ "name": "scala",
+ "mimetype": "text/x-scala",
+ "codemirror_mode": "text/x-scala",
+ "file_extension": ".sc",
+ "pygments_lexer": "scala",
}
session_language = LANG_SCALA
- super(SparkKernel,
- self).__init__(implementation, implementation_version, language,
- language_version, language_info, session_language,
- **kwargs)
+ super(SparkKernel, self).__init__(
+ implementation,
+ implementation_version,
+ language,
+ language_version,
+ language_info,
+ session_language,
+ **kwargs
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
from ipykernel.kernelapp import IPKernelApp
+
IPKernelApp.launch_instance(kernel_class=SparkKernel)
diff --git a/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py b/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py
index 99c4176c3..f102a9cad 100644
--- a/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py
+++ b/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py
@@ -1 +1 @@
-__version__ = '0.0.1'
\ No newline at end of file
+__version__ = "0.0.1"
diff --git a/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py b/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py
index c55567e65..85c95e790 100644
--- a/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py
+++ b/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py
@@ -6,26 +6,32 @@
class SparkRKernel(SparkKernelBase):
def __init__(self, **kwargs):
- implementation = 'SparkR'
- implementation_version = '1.0'
+ implementation = "SparkR"
+ implementation_version = "1.0"
language = LANG_R
- language_version = '0.1'
+ language_version = "0.1"
language_info = {
- 'name': 'sparkR',
- 'mimetype': 'text/x-rsrc',
- 'codemirror_mode': 'text/x-rsrc',
- 'file_extension': '.r',
- 'pygments_lexer': 'r'
+ "name": "sparkR",
+ "mimetype": "text/x-rsrc",
+ "codemirror_mode": "text/x-rsrc",
+ "file_extension": ".r",
+ "pygments_lexer": "r",
}
session_language = LANG_R
- super(SparkRKernel,
- self).__init__(implementation, implementation_version, language,
- language_version, language_info, session_language,
- **kwargs)
+ super(SparkRKernel, self).__init__(
+ implementation,
+ implementation_version,
+ language,
+ language_version,
+ language_info,
+ session_language,
+ **kwargs
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
from ipykernel.kernelapp import IPKernelApp
+
IPKernelApp.launch_instance(kernel_class=SparkRKernel)
diff --git a/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py b/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py
index 55c80e475..bd45b2efa 100644
--- a/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py
+++ b/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py
@@ -3,9 +3,11 @@
try:
from asyncio import Future
except ImportError:
+
class Future(object):
"""A class nothing will use."""
+
import requests
from ipykernel.ipkernel import IPythonKernel
from hdijupyterutils.ipythondisplay import IpythonDisplay
@@ -17,8 +19,17 @@ class Future(object):
class SparkKernelBase(IPythonKernel):
- def __init__(self, implementation, implementation_version, language, language_version, language_info,
- session_language, user_code_parser=None, **kwargs):
+ def __init__(
+ self,
+ implementation,
+ implementation_version,
+ language,
+ language_version,
+ language_info,
+ session_language,
+ user_code_parser=None,
+ **kwargs
+ ):
# Required by Jupyter - Override
self.implementation = implementation
self.implementation_version = implementation_version
@@ -31,7 +42,7 @@ def __init__(self, implementation, implementation_version, language, language_ve
super(SparkKernelBase, self).__init__(**kwargs)
- self.logger = SparkLog(u"{}_jupyter_kernel".format(self.session_language))
+ self.logger = SparkLog("{}_jupyter_kernel".format(self.session_language))
self._fatal_error = None
self.ipython_display = IpythonDisplay()
@@ -49,12 +60,17 @@ def __init__(self, implementation, implementation_version, language, language_ve
if conf.use_auto_viz():
self._register_auto_viz()
- def do_execute(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False):
+ def do_execute(
+ self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
+ ):
def f(self):
if self._fatal_error is not None:
return self._repeat_fatal_error()
- return self._do_execute(code, silent, store_history, user_expressions, allow_stdin)
+ return self._do_execute(
+ code, silent, store_history, user_expressions, allow_stdin
+ )
+
return wrap_unexpected_exceptions(f, self._complete_cell)(self)
def do_shutdown(self, restart):
@@ -66,54 +82,91 @@ def do_shutdown(self, restart):
def _do_execute(self, code, silent, store_history, user_expressions, allow_stdin):
code_to_run = self.user_code_parser.get_code_to_run(code)
- res = self._execute_cell(code_to_run, silent, store_history, user_expressions, allow_stdin)
+ res = self._execute_cell(
+ code_to_run, silent, store_history, user_expressions, allow_stdin
+ )
return res
def _load_magics_extension(self):
register_magics_code = "%load_ext sparkmagic.kernels"
- self._execute_cell(register_magics_code, True, False, shutdown_if_error=True,
- log_if_error="Failed to load the Spark kernels magics library.")
+ self._execute_cell(
+ register_magics_code,
+ True,
+ False,
+ shutdown_if_error=True,
+ log_if_error="Failed to load the Spark kernels magics library.",
+ )
self.logger.debug("Loaded magics.")
def _change_language(self):
- register_magics_code = "%%_do_not_call_change_language -l {}\n ".format(self.session_language)
- self._execute_cell(register_magics_code, True, False, shutdown_if_error=True,
- log_if_error="Failed to change language to {}.".format(self.session_language))
+ register_magics_code = "%%_do_not_call_change_language -l {}\n ".format(
+ self.session_language
+ )
+ self._execute_cell(
+ register_magics_code,
+ True,
+ False,
+ shutdown_if_error=True,
+ log_if_error="Failed to change language to {}.".format(
+ self.session_language
+ ),
+ )
self.logger.debug("Changed language.")
def _register_auto_viz(self):
from sparkmagic.utils.sparkevents import get_spark_events_handler
import autovizwidget.utils.configuration as c
-
+
handler = get_spark_events_handler()
c.override("events_handler", handler)
-
+
register_auto_viz_code = """from autovizwidget.widget.utils import display_dataframe
ip = get_ipython()
ip.display_formatter.ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)"""
- self._execute_cell(register_auto_viz_code, True, False, shutdown_if_error=True,
- log_if_error="Failed to register auto viz for notebook.")
+ self._execute_cell(
+ register_auto_viz_code,
+ True,
+ False,
+ shutdown_if_error=True,
+ log_if_error="Failed to register auto viz for notebook.",
+ )
self.logger.debug("Registered auto viz.")
def _delete_session(self):
code = "%%_do_not_call_delete_session\n "
self._execute_cell_for_user(code, True, False)
- def _execute_cell(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False,
- shutdown_if_error=False, log_if_error=None):
- reply_content = self._execute_cell_for_user(code, silent, store_history, user_expressions, allow_stdin)
-
- if shutdown_if_error and reply_content[u"status"] == u"error":
- error_from_reply = reply_content[u"evalue"]
+ def _execute_cell(
+ self,
+ code,
+ silent,
+ store_history=True,
+ user_expressions=None,
+ allow_stdin=False,
+ shutdown_if_error=False,
+ log_if_error=None,
+ ):
+ reply_content = self._execute_cell_for_user(
+ code, silent, store_history, user_expressions, allow_stdin
+ )
+
+ if shutdown_if_error and reply_content["status"] == "error":
+ error_from_reply = reply_content["evalue"]
if log_if_error is not None:
- message = "{}\nException details:\n\t\"{}\"".format(log_if_error, error_from_reply)
+ message = '{}\nException details:\n\t"{}"'.format(
+ log_if_error, error_from_reply
+ )
return self._abort_with_fatal_error(message)
return reply_content
- def _execute_cell_for_user(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False):
- result = super(SparkKernelBase, self).do_execute(code, silent, store_history, user_expressions, allow_stdin)
+ def _execute_cell_for_user(
+ self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
+ ):
+ result = super(SparkKernelBase, self).do_execute(
+ code, silent, store_history, user_expressions, allow_stdin
+ )
if isinstance(result, Future):
result = result.result()
return result
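
The Future shim at the top of this file and the isinstance check in _execute_cell_for_user exist because newer ipykernel versions return a Future from do_execute while older ones return the reply dict directly. The normalization pattern in isolation (a sketch, not sparkmagic API):

    from asyncio import Future

    def resolve_reply(result):
        # Newer ipykernel: a Future resolving to the reply dict.
        # Older ipykernel: the reply dict itself.
        if isinstance(result, Future):
            return result.result()
        return result
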
diff --git a/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py b/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py
index 1b3f1b4fb..47efbd994 100644
--- a/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py
+++ b/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py
@@ -9,9 +9,18 @@ class UserCodeParser(object):
# For example, the %%info magic has no cell body input, i.e. it is incorrect to call
# %%info
# some_input
- _magics_with_no_cell_body = [i.__name__ for i in [KernelMagics.info, KernelMagics.logs, KernelMagics.cleanup,
- KernelMagics.delete, KernelMagics.help, KernelMagics.spark,
- KernelMagics.send_to_spark]]
+ _magics_with_no_cell_body = [
+ i.__name__
+ for i in [
+ KernelMagics.info,
+ KernelMagics.logs,
+ KernelMagics.cleanup,
+ KernelMagics.delete,
+ KernelMagics.help,
+ KernelMagics.spark,
+ KernelMagics.send_to_spark,
+ ]
+ ]
def get_code_to_run(self, code):
try:
@@ -22,9 +31,9 @@ def get_code_to_run(self, code):
if code.startswith("%%local") or code.startswith("%local"):
return all_but_first_line
elif any(code.startswith("%%" + s) for s in self._magics_with_no_cell_body):
- return u"{}\n ".format(code)
+ return "{}\n ".format(code)
elif any(code.startswith("%" + s) for s in self._magics_with_no_cell_body):
- return u"%{}\n ".format(code)
+ return "%{}\n ".format(code)
elif code.startswith("%%") or code.startswith("%"):
# If they use other line magics:
# %autosave
@@ -34,4 +43,4 @@ def get_code_to_run(self, code):
elif not code:
return code
else:
- return u"%%spark\n{}".format(code)
+ return "%%spark\n{}".format(code)
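
Taken together, the branches above route plain cells to %%spark and pad bodiless magics with a trailing "\n " so IPython still treats them as cell magics. A usage sketch (assuming a sparkmagic install):

    from sparkmagic.kernels.wrapperkernel.usercodeparser import UserCodeParser

    parser = UserCodeParser()
    parser.get_code_to_run("df.count()")      # -> "%%spark\ndf.count()"
    parser.get_code_to_run("%%info")          # -> "%%info\n "
    parser.get_code_to_run("%%local\nx = 1")  # -> "x = 1"
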
diff --git a/sparkmagic/sparkmagic/livyclientlib/command.py b/sparkmagic/sparkmagic/livyclientlib/command.py
index 309a6158a..3a03266f3 100644
--- a/sparkmagic/sparkmagic/livyclientlib/command.py
+++ b/sparkmagic/sparkmagic/livyclientlib/command.py
@@ -11,18 +11,27 @@
import sparkmagic.utils.configuration as conf
from sparkmagic.utils.sparklogger import SparkLog
from sparkmagic.utils.sparkevents import SparkEvents
-from sparkmagic.utils.constants import MAGICS_LOGGER_NAME, FINAL_STATEMENT_STATUS, \
- MIMETYPE_IMAGE_PNG, MIMETYPE_TEXT_HTML, MIMETYPE_TEXT_PLAIN, \
- COMMAND_INTERRUPTED_MSG, COMMAND_CANCELLATION_FAILED_MSG
-from .exceptions import LivyUnexpectedStatusException, SparkStatementCancelledException, \
- SparkStatementCancellationFailedException
+from sparkmagic.utils.constants import (
+ MAGICS_LOGGER_NAME,
+ FINAL_STATEMENT_STATUS,
+ MIMETYPE_IMAGE_PNG,
+ MIMETYPE_TEXT_HTML,
+ MIMETYPE_TEXT_PLAIN,
+ COMMAND_INTERRUPTED_MSG,
+ COMMAND_CANCELLATION_FAILED_MSG,
+)
+from .exceptions import (
+ LivyUnexpectedStatusException,
+ SparkStatementCancelledException,
+ SparkStatementCancellationFailedException,
+)
class Command(ObjectWithGuid):
def __init__(self, code, spark_events=None):
super(Command, self).__init__()
self.code = textwrap.dedent(code)
- self.logger = SparkLog(u"Command")
+ self.logger = SparkLog("Command")
if spark_events is None:
spark_events = SparkEvents()
self._spark_events = spark_events
@@ -37,68 +46,99 @@ def __ne__(self, other):
return not self == other
def execute(self, session):
- self._spark_events.emit_statement_execution_start_event(session.guid, session.kind, session.id, self.guid)
+ self._spark_events.emit_statement_execution_start_event(
+ session.guid, session.kind, session.id, self.guid
+ )
statement_id = -1
try:
session.wait_for_idle()
- data = {u"code": self.code}
+ data = {"code": self.code}
response = session.http_client.post_statement(session.id, data)
- statement_id = response[u'id']
+ statement_id = response["id"]
output = self._get_statement_output(session, statement_id)
except KeyboardInterrupt as e:
- self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id,
- self.guid, statement_id, False, e.__class__.__name__,
- str(e))
+ self._spark_events.emit_statement_execution_end_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ statement_id,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
try:
if statement_id >= 0:
- response = session.http_client.cancel_statement(session.id, statement_id)
+ response = session.http_client.cancel_statement(
+ session.id, statement_id
+ )
session.wait_for_idle()
except:
- raise SparkStatementCancellationFailedException(COMMAND_CANCELLATION_FAILED_MSG)
+ raise SparkStatementCancellationFailedException(
+ COMMAND_CANCELLATION_FAILED_MSG
+ )
else:
raise SparkStatementCancelledException(COMMAND_INTERRUPTED_MSG)
except Exception as e:
- self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id,
- self.guid, statement_id, False, e.__class__.__name__,
- str(e))
+ self._spark_events.emit_statement_execution_end_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ statement_id,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
raise
else:
- self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id,
- self.guid, statement_id, True, "", "")
+ self._spark_events.emit_statement_execution_end_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ statement_id,
+ True,
+ "",
+ "",
+ )
return output
def _get_statement_output(self, session, statement_id):
retries = 1
- progress = FloatProgress(value=0.0,
- min=0,
- max=1.0,
- step=0.01,
- description='Progress:',
- bar_style='info',
- orientation='horizontal',
- layout=Layout(width='50%', height='25px')
- )
+ progress = FloatProgress(
+ value=0.0,
+ min=0,
+ max=1.0,
+ step=0.01,
+ description="Progress:",
+ bar_style="info",
+ orientation="horizontal",
+ layout=Layout(width="50%", height="25px"),
+ )
session.ipython_display.display(progress)
while True:
statement = session.http_client.get_statement(session.id, statement_id)
- status = statement[u"state"].lower()
+ status = statement["state"].lower()
- self.logger.debug(u"Status of statement {} is {}.".format(statement_id, status))
+ self.logger.debug(
+ "Status of statement {} is {}.".format(statement_id, status)
+ )
if status not in FINAL_STATEMENT_STATUS:
- progress.value = statement.get('progress', 0.0)
+ progress.value = statement.get("progress", 0.0)
session.sleep(retries)
retries += 1
else:
- statement_output = statement[u"output"]
+ statement_output = statement["output"]
progress.close()
if statement_output is None:
- return (True, u"", MIMETYPE_TEXT_PLAIN)
+ return (True, "", MIMETYPE_TEXT_PLAIN)
- if statement_output[u"status"] == u"ok":
- data = statement_output[u"data"]
+ if statement_output["status"] == "ok":
+ data = statement_output["data"]
if MIMETYPE_IMAGE_PNG in data:
image = Image(base64.b64decode(data[MIMETYPE_IMAGE_PNG]))
return (True, image, MIMETYPE_IMAGE_PNG)
@@ -106,10 +146,17 @@ def _get_statement_output(self, session, statement_id):
return (True, data[MIMETYPE_TEXT_HTML], MIMETYPE_TEXT_HTML)
else:
return (True, data[MIMETYPE_TEXT_PLAIN], MIMETYPE_TEXT_PLAIN)
- elif statement_output[u"status"] == u"error":
- return (False,
- statement_output[u"evalue"] + u"\n" + u"".join(statement_output[u"traceback"]),
- MIMETYPE_TEXT_PLAIN)
+ elif statement_output["status"] == "error":
+ return (
+ False,
+ statement_output["evalue"]
+ + "\n"
+ + "".join(statement_output["traceback"]),
+ MIMETYPE_TEXT_PLAIN,
+ )
else:
- raise LivyUnexpectedStatusException(u"Unknown output status from Livy: '{}'"
- .format(statement_output[u"status"]))
+ raise LivyUnexpectedStatusException(
+ "Unknown output status from Livy: '{}'".format(
+ statement_output["status"]
+ )
+ )
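
Callers of Command.execute receive a (success, output, mimetype) triple, with output already decoded into text, HTML, or an IPython Image as shown above. A consumption sketch, where session is a hypothetical already-started LivySession:

    from sparkmagic.livyclientlib.command import Command

    command = Command("spark.range(10).count()")
    success, output, mimetype = command.execute(session)
    if success:
        print(mimetype, output)
    else:
        print("Statement failed:\n" + output)  # evalue plus traceback text
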
diff --git a/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py b/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py
index 8ed854744..f4de99b81 100644
--- a/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py
+++ b/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py
@@ -18,7 +18,9 @@ def __init__(self, retry_seconds_to_sleep_list, max_retries):
if len(retry_seconds_to_sleep_list) == 0:
retry_seconds_to_sleep_list = [5]
elif not all(n > 0 for n in retry_seconds_to_sleep_list):
- raise BadUserConfigurationException(u"All items in the list in your config need to be positive for configurable retry policy")
+ raise BadUserConfigurationException(
+ "All items in the list in your config need to be positive for configurable retry policy"
+ )
self.retry_seconds_to_sleep_list = retry_seconds_to_sleep_list
self._max_index = len(self.retry_seconds_to_sleep_list) - 1
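
_max_index caches the last valid position in the sleep list; the intended effect (assumed here, since seconds_to_sleep itself is outside this hunk) is that retries past the end of the list keep reusing the final interval until max_retries is exhausted:

    sleep_list = [0.2, 0.5, 0.5, 1, 1, 2]
    max_index = len(sleep_list) - 1

    def seconds_to_sleep(retry_count):
        # Retry 7, 8, 9, ... all sleep for the last configured value.
        return sleep_list[min(retry_count, max_index)]
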
diff --git a/sparkmagic/sparkmagic/livyclientlib/endpoint.py b/sparkmagic/sparkmagic/livyclientlib/endpoint.py
index d49744382..59b32d36b 100644
--- a/sparkmagic/sparkmagic/livyclientlib/endpoint.py
+++ b/sparkmagic/sparkmagic/livyclientlib/endpoint.py
@@ -1,11 +1,12 @@
from .exceptions import BadUserDataException
+
class Endpoint(object):
def __init__(self, url, auth, implicitly_added=False):
if not url:
- raise BadUserDataException(u"URL must not be empty")
+ raise BadUserDataException("URL must not be empty")
- self.url = url.rstrip(u"/")
+ self.url = url.rstrip("/")
self.auth = auth
# implicitly_added is set to True only if the endpoint wasn't configured manually by the user through
# a widget, but was instead implicitly defined as an endpoint to a wrapper kernel in the configuration
@@ -24,4 +25,4 @@ def __ne__(self, other):
return not self == other
def __str__(self):
- return u"Endpoint({})".format(self.url)
+ return "Endpoint({})".format(self.url)
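
Endpoint behavior in isolation (a sketch; auth is None purely for illustration):

    from sparkmagic.livyclientlib.endpoint import Endpoint

    ep = Endpoint("http://livy-server:8998/", None)
    ep.url              # "http://livy-server:8998" -- trailing slash stripped
    str(ep)             # "Endpoint(http://livy-server:8998)"
    Endpoint("", None)  # raises BadUserDataException("URL must not be empty")
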
diff --git a/sparkmagic/sparkmagic/livyclientlib/exceptions.py b/sparkmagic/sparkmagic/livyclientlib/exceptions.py
index 5c3bb0e5f..b048aa49f 100644
--- a/sparkmagic/sparkmagic/livyclientlib/exceptions.py
+++ b/sparkmagic/sparkmagic/livyclientlib/exceptions.py
@@ -24,6 +24,7 @@ class HttpClientException(LivyClientLibException):
class LivyClientTimeoutException(LivyClientLibException):
"""An exception for timeouts while interacting with Livy."""
+
class DataFrameParseException(LivyClientLibException):
"""An internal error which suggests a bad implementation of dataframe parsing from JSON --
if we get a JSON parsing error when parsing the results from the Livy server, this exception
@@ -88,7 +89,17 @@ class SparkStatementCancellationFailedException(KeyboardInterrupt):
# == DECORATORS FOR EXCEPTION HANDLING ==
-EXPECTED_EXCEPTIONS = [BadUserConfigurationException, BadUserDataException, LivyUnexpectedStatusException, SqlContextNotFoundException, HttpClientException, LivyClientTimeoutException, SessionManagementException, SparkStatementException]
+EXPECTED_EXCEPTIONS = [
+ BadUserConfigurationException,
+ BadUserDataException,
+ LivyUnexpectedStatusException,
+ SqlContextNotFoundException,
+ HttpClientException,
+ LivyClientTimeoutException,
+ SessionManagementException,
+ SparkStatementException,
+]
+
def handle_expected_exceptions(f):
"""A decorator that handles expected exceptions. Self can be any object with
@@ -98,6 +109,7 @@ def handle_expected_exceptions(f):
def fn(self, ...):
etc..."""
from sparkmagic.utils import configuration as conf
+
exceptions_to_handle = tuple(EXPECTED_EXCEPTIONS)
# Notice that we're NOT handling e.DataFrameParseException here. That's because DataFrameParseException
@@ -114,6 +126,7 @@ def wrapped(self, *args, **kwargs):
return None
else:
return out
+
wrapped.__name__ = f.__name__
wrapped.__doc__ = f.__doc__
return wrapped
@@ -128,10 +141,15 @@ def wrap_unexpected_exceptions(f, execute_if_error=None):
Usage:
@wrap_unexpected_exceptions
def fn(self, ...):
- ..etc """
+ ..etc"""
from sparkmagic.utils import configuration as conf
+
def handle_exception(self, e):
- self.logger.error(u"ENCOUNTERED AN INTERNAL ERROR: {}\n\tTraceback:\n{}".format(e, traceback.format_exc()))
+ self.logger.error(
+ "ENCOUNTERED AN INTERNAL ERROR: {}\n\tTraceback:\n{}".format(
+ e, traceback.format_exc()
+ )
+ )
self.ipython_display.send_error(INTERNAL_ERROR_MSG.format(e))
return None if execute_if_error is None else execute_if_error()
@@ -144,6 +162,7 @@ def wrapped(self, *args, **kwargs):
return handle_exception(self, err)
else:
return out
+
wrapped.__name__ = f.__name__
wrapped.__doc__ = f.__doc__
return wrapped
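
Both decorators only require that self carry a logger and an ipython_display, per the docstrings above. A minimal sketch of a conforming class:

    from hdijupyterutils.ipythondisplay import IpythonDisplay
    from sparkmagic.utils.sparklogger import SparkLog
    from sparkmagic.livyclientlib.exceptions import (
        BadUserDataException,
        handle_expected_exceptions,
    )

    class SomeMagics(object):
        def __init__(self):
            self.logger = SparkLog("SomeMagics")
            self.ipython_display = IpythonDisplay()

        @handle_expected_exceptions
        def run(self):
            # An expected exception: reported to the user, then None returned.
            raise BadUserDataException("bad input")
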
diff --git a/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py b/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py
index 176ec7b79..fa54f154e 100644
--- a/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py
+++ b/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py
@@ -12,22 +12,29 @@
class LivyReliableHttpClient(object):
"""A Livy-specific Http client which wraps the normal ReliableHttpClient. Propagates
HttpClientExceptions up."""
+
def __init__(self, http_client, endpoint):
self.endpoint = endpoint
self._http_client = http_client
@staticmethod
def from_endpoint(endpoint):
- headers = {"Content-Type": "application/json" }
+ headers = {"Content-Type": "application/json"}
headers.update(conf.custom_headers())
retry_policy = LivyReliableHttpClient._get_retry_policy()
- return LivyReliableHttpClient(ReliableHttpClient(endpoint, headers, retry_policy), endpoint)
+ return LivyReliableHttpClient(
+ ReliableHttpClient(endpoint, headers, retry_policy), endpoint
+ )
def post_statement(self, session_id, data):
- return self._http_client.post(self._statements_url(session_id), [201], data).json()
+ return self._http_client.post(
+ self._statements_url(session_id), [201], data
+ ).json()
def get_statement(self, session_id, statement_id):
- return self._http_client.get(self._statement_url(session_id, statement_id), [200]).json()
+ return self._http_client.get(
+ self._statement_url(session_id, statement_id), [200]
+ ).json()
def get_sessions(self):
return self._http_client.get("/sessions", [200]).json()
@@ -42,13 +49,17 @@ def delete_session(self, session_id):
self._http_client.delete(self._session_url(session_id), [200, 404])
def get_all_session_logs(self, session_id):
- return self._http_client.get(self._session_url(session_id) + "/log?from=0", [200]).json()
+ return self._http_client.get(
+ self._session_url(session_id) + "/log?from=0", [200]
+ ).json()
def get_headers(self):
return self._http_client.get_headers()
def cancel_statement(self, session_id, statement_id):
- return self._http_client.post("{}/cancel".format(self._statement_url(session_id, statement_id)), [200], {}).json()
+ return self._http_client.post(
+ "{}/cancel".format(self._statement_url(session_id, statement_id)), [200], {}
+ ).json()
@staticmethod
def _session_url(session_id):
@@ -69,6 +80,11 @@ def _get_retry_policy():
if policy == LINEAR_RETRY:
return LinearRetryPolicy(seconds_to_sleep=5, max_retries=5)
elif policy == CONFIGURABLE_RETRY:
- return ConfigurableRetryPolicy(retry_seconds_to_sleep_list=conf.retry_seconds_to_sleep_list(), max_retries=conf.configurable_retry_policy_max_retries())
+ return ConfigurableRetryPolicy(
+ retry_seconds_to_sleep_list=conf.retry_seconds_to_sleep_list(),
+ max_retries=conf.configurable_retry_policy_max_retries(),
+ )
else:
- raise BadUserConfigurationException(u"Retry policy '{}' not supported".format(policy))
+ raise BadUserConfigurationException(
+ "Retry policy '{}' not supported".format(policy)
+ )
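
_get_retry_policy keys off user configuration; the relevant entries in ~/.sparkmagic/config.json look roughly like this (key names inferred from the conf accessors above):

    config = {
        "retry_policy": "configurable",  # or "linear"
        "retry_seconds_to_sleep_list": [0.2, 0.5, 1, 3, 5],
        "configurable_retry_policy_max_retries": 8,
    }
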
diff --git a/sparkmagic/sparkmagic/livyclientlib/livysession.py b/sparkmagic/sparkmagic/livyclientlib/livysession.py
index 81512d856..58643034c 100644
--- a/sparkmagic/sparkmagic/livyclientlib/livysession.py
+++ b/sparkmagic/sparkmagic/livyclientlib/livysession.py
@@ -11,8 +11,12 @@
from sparkmagic.utils.utils import get_sessions_info_html
from .configurableretrypolicy import ConfigurableRetryPolicy
from .command import Command
-from .exceptions import LivyClientTimeoutException, \
- LivyUnexpectedStatusException, BadUserDataException, SqlContextNotFoundException
+from .exceptions import (
+ LivyClientTimeoutException,
+ LivyUnexpectedStatusException,
+ BadUserDataException,
+ SqlContextNotFoundException,
+)
class _HeartbeatThread(threading.Thread):
@@ -41,10 +45,12 @@ def __init__(self, livy_session, refresh_seconds, retry_seconds, run_at_most=Non
def run(self):
loop_counter = 0
if self.livy_session is None:
- print(u"Will not start heartbeat thread because self.livy_session is None")
+ print("Will not start heartbeat thread because self.livy_session is None")
return
- self.livy_session.logger.info(u'Starting heartbeat for session {}'.format(self.livy_session.id))
+ self.livy_session.logger.info(
+ "Starting heartbeat for session {}".format(self.livy_session.id)
+ )
while self.livy_session is not None and loop_counter < self.run_at_most:
loop_counter += 1
@@ -58,23 +64,31 @@ def run(self):
# the "exception" function in the SparkLog class then you could just make this
# self.livy_session.logger.exception("some useful message") and it'll print
# out the stack trace too.
- self.livy_session.logger.error(u'{}'.format(e))
+ self.livy_session.logger.error("{}".format(e))
sleep(sleep_time)
-
def stop(self):
if self.livy_session is not None:
- self.livy_session.logger.info(u'Stopping heartbeat for session {}'.format(self.livy_session.id))
+ self.livy_session.logger.info(
+ "Stopping heartbeat for session {}".format(self.livy_session.id)
+ )
self.livy_session = None
self.join()
class LivySession(ObjectWithGuid):
- def __init__(self, http_client, properties, ipython_display,
- session_id=-1, spark_events=None,
- heartbeat_timeout=0, heartbeat_thread=None):
+ def __init__(
+ self,
+ http_client,
+ properties,
+ ipython_display,
+ session_id=-1,
+ spark_events=None,
+ heartbeat_timeout=0,
+ heartbeat_thread=None,
+ ):
super(LivySession, self).__init__()
assert constants.LIVY_KIND_PARAM in list(properties.keys())
kind = properties[constants.LIVY_KIND_PARAM]
@@ -95,28 +109,33 @@ def __init__(self, http_client, properties, ipython_display,
spark_events = SparkEvents()
self._spark_events = spark_events
- self._policy = ConfigurableRetryPolicy(retry_seconds_to_sleep_list=[0.2, 0.5, 0.5, 1, 1, 2], max_retries=5000)
+ self._policy = ConfigurableRetryPolicy(
+ retry_seconds_to_sleep_list=[0.2, 0.5, 0.5, 1, 1, 2], max_retries=5000
+ )
wait_for_idle_timeout_seconds = conf.wait_for_idle_timeout_seconds()
assert wait_for_idle_timeout_seconds > 0
- self.logger = SparkLog(u"LivySession")
+ self.logger = SparkLog("LivySession")
kind = kind.lower()
if kind not in constants.SESSION_KINDS_SUPPORTED:
- raise BadUserDataException(u"Session of kind '{}' not supported. Session must be of kinds {}."
- .format(kind, ", ".join(constants.SESSION_KINDS_SUPPORTED)))
+ raise BadUserDataException(
+ "Session of kind '{}' not supported. Session must be of kinds {}.".format(
+ kind, ", ".join(constants.SESSION_KINDS_SUPPORTED)
+ )
+ )
self._app_id = None
self._user = None
- self._logs = u""
+ self._logs = ""
self._http_client = http_client
self._wait_for_idle_timeout_seconds = wait_for_idle_timeout_seconds
self._printed_resource_warning = False
self.kind = kind
self.id = session_id
- self.session_info = u""
+ self.session_info = ""
self._heartbeat_thread = None
if session_id == -1:
@@ -126,8 +145,14 @@ def __init__(self, http_client, properties, ipython_display,
self._start_heartbeat_thread()
def __str__(self):
- return u"Session id: {}\tYARN id: {}\tKind: {}\tState: {}\n\tSpark UI: {}\n\tDriver Log: {}"\
- .format(self.id, self.get_app_id(), self.kind, self.status, self.get_spark_ui_url(), self.get_driver_log_url())
+ return "Session id: {}\tYARN id: {}\tKind: {}\tState: {}\n\tSpark UI: {}\n\tDriver Log: {}".format(
+ self.id,
+ self.get_app_id(),
+ self.kind,
+ self.status,
+ self.get_spark_ui_url(),
+ self.get_driver_log_url(),
+ )
def start(self):
"""Start the session against actual livy server."""
@@ -136,10 +161,10 @@ def start(self):
try:
r = self._http_client.post_session(self.properties)
- self.id = r[u"id"]
- self.status = str(r[u"state"])
+ self.id = r["id"]
+ self.status = str(r["state"])
- self.ipython_display.writeln(u"Starting Spark application")
+ self.ipython_display.writeln("Starting Spark application")
# Start heartbeat thread to keep Livy interactive session alive.
self._start_heartbeat_thread()
@@ -148,8 +173,11 @@ def start(self):
try:
self.wait_for_idle(conf.livy_session_startup_timeout_seconds())
except LivyClientTimeoutException:
- raise LivyClientTimeoutException(u"Session {} did not start up in {} seconds."
- .format(self.id, conf.livy_session_startup_timeout_seconds()))
+ raise LivyClientTimeoutException(
+ "Session {} did not start up in {} seconds.".format(
+ self.id, conf.livy_session_startup_timeout_seconds()
+ )
+ )
html = get_sessions_info_html([self], self.id)
self.ipython_display.html(html)
@@ -158,26 +186,41 @@ def start(self):
(success, out, mimetype) = command.execute(self)
if success:
- self.ipython_display.writeln(u"SparkSession available as 'spark'.")
+ self.ipython_display.writeln("SparkSession available as 'spark'.")
self.sql_context_variable_name = "spark"
else:
command = Command("sqlContext")
(success, out, mimetype) = command.execute(self)
if success:
- self.ipython_display.writeln(u"SparkContext available as 'sc'.")
- if ("hive" in out.lower()):
- self.ipython_display.writeln(u"HiveContext available as 'sqlContext'.")
+ self.ipython_display.writeln("SparkContext available as 'sc'.")
+ if "hive" in out.lower():
+ self.ipython_display.writeln(
+ "HiveContext available as 'sqlContext'."
+ )
else:
- self.ipython_display.writeln(u"SqlContext available as 'sqlContext'.")
+ self.ipython_display.writeln(
+ "SqlContext available as 'sqlContext'."
+ )
self.sql_context_variable_name = "sqlContext"
else:
- raise SqlContextNotFoundException(u"Neither SparkSession nor HiveContext/SqlContext is available.")
+ raise SqlContextNotFoundException(
+ "Neither SparkSession nor HiveContext/SqlContext is available."
+ )
except Exception as e:
- self._spark_events.emit_session_creation_end_event(self.guid, self.kind, self.id, self.status,
- False, e.__class__.__name__, str(e))
+ self._spark_events.emit_session_creation_end_event(
+ self.guid,
+ self.kind,
+ self.id,
+ self.status,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
raise
else:
- self._spark_events.emit_session_creation_end_event(self.guid, self.kind, self.id, self.status, True, "", "")
+ self._spark_events.emit_session_creation_end_event(
+ self.guid, self.kind, self.id, self.status, True, "", ""
+ )
def get_app_id(self):
if self._app_id is None:
@@ -195,7 +238,7 @@ def get_driver_log_url(self):
return self.get_app_info_member("driverLogUrl")
def get_logs(self):
- log_array = self._http_client.get_all_session_logs(self.id)[u'log']
+ log_array = self._http_client.get_all_session_logs(self.id)["log"]
self._logs = "\n".join(log_array)
return self._logs
@@ -225,10 +268,12 @@ def is_posted(self):
def delete(self):
session_id = self.id
- self._spark_events.emit_session_deletion_start_event(self.guid, self.kind, session_id, self.status)
+ self._spark_events.emit_session_deletion_start_event(
+ self.guid, self.kind, session_id, self.status
+ )
try:
- self.logger.debug(u"Deleting session '{}'".format(session_id))
+ self.logger.debug("Deleting session '{}'".format(session_id))
if self.status != constants.NOT_STARTED_SESSION_STATUS:
self._http_client.delete_session(session_id)
@@ -236,15 +281,27 @@ def delete(self):
self.status = constants.DEAD_SESSION_STATUS
self.id = -1
else:
- self.ipython_display.send_error(u"Cannot delete session {} that is in state '{}'."
- .format(session_id, self.status))
+ self.ipython_display.send_error(
+ "Cannot delete session {} that is in state '{}'.".format(
+ session_id, self.status
+ )
+ )
except Exception as e:
- self._spark_events.emit_session_deletion_end_event(self.guid, self.kind, session_id, self.status, False,
- e.__class__.__name__, str(e))
+ self._spark_events.emit_session_deletion_end_event(
+ self.guid,
+ self.kind,
+ session_id,
+ self.status,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
raise
else:
- self._spark_events.emit_session_deletion_end_event(self.guid, self.kind, session_id, self.status, True, "", "")
+ self._spark_events.emit_session_deletion_end_event(
+ self.guid, self.kind, session_id, self.status, True, "", ""
+ )
def wait_for_idle(self, seconds_to_wait=None):
"""Wait for session to go to idle status. Sleep meanwhile.
@@ -262,29 +319,41 @@ def wait_for_idle(self, seconds_to_wait=None):
return
if self.status in constants.FINAL_STATUS:
- error = u"Session {} unexpectedly reached final status '{}'."\
- .format(self.id, self.status)
+ error = "Session {} unexpectedly reached final status '{}'.".format(
+ self.id, self.status
+ )
self.logger.error(error)
- raise LivyUnexpectedStatusException(u'{} See logs:\n{}'.format(error, self.get_logs()))
+ raise LivyUnexpectedStatusException(
+ "{} See logs:\n{}".format(error, self.get_logs())
+ )
if seconds_to_wait <= 0.0:
- error = u"Session {} did not reach idle status in time. Current status is {}."\
- .format(self.id, self.status)
+ error = "Session {} did not reach idle status in time. Current status is {}.".format(
+ self.id, self.status
+ )
self.logger.error(error)
raise LivyClientTimeoutException(error)
- if constants.YARN_RESOURCE_LIMIT_MSG in self.session_info and \
- not self._printed_resource_warning:
- self.ipython_display.send_error(constants.RESOURCE_LIMIT_WARNING\
- .format(conf.resource_limit_mitigation_suggestion()))
+ if (
+ constants.YARN_RESOURCE_LIMIT_MSG in self.session_info
+ and not self._printed_resource_warning
+ ):
+ self.ipython_display.send_error(
+ constants.RESOURCE_LIMIT_WARNING.format(
+ conf.resource_limit_mitigation_suggestion()
+ )
+ )
self._printed_resource_warning = True
start_time = time()
sleep_time = self._policy.seconds_to_sleep(retries)
retries += 1
- self.logger.debug(u"Session {} in state {}. Sleeping {} seconds."
- .format(self.id, self.status, sleep_time))
+ self.logger.debug(
+ "Session {} in state {}. Sleeping {} seconds.".format(
+ self.id, self.status, sleep_time
+ )
+ )
sleep(sleep_time)
seconds_to_wait -= time() - start_time
@@ -295,14 +364,16 @@ def sleep(self, retries):
# Only the status will be returned as the return value.
def refresh_status_and_info(self):
response = self._http_client.get_session(self.id)
- status = response[u'state']
- log_array = response[u'log']
+ status = response["state"]
+ log_array = response["log"]
if status in constants.POSSIBLE_SESSION_STATUS:
self.status = status
- self.session_info = u"\n".join(log_array)
+ self.session_info = "\n".join(log_array)
else:
- raise LivyUnexpectedStatusException(u"Status '{}' not supported by session.".format(status))
+ raise LivyUnexpectedStatusException(
+ "Status '{}' not supported by session.".format(status)
+ )
def _start_heartbeat_thread(self):
if self._should_heartbeat and self._heartbeat_thread is None:
@@ -310,7 +381,9 @@ def _start_heartbeat_thread(self):
retry_seconds = conf.heartbeat_retry_seconds()
if self._user_passed_heartbeat_thread is None:
- self._heartbeat_thread = _HeartbeatThread(self, refresh_seconds, retry_seconds)
+ self._heartbeat_thread = _HeartbeatThread(
+ self, refresh_seconds, retry_seconds
+ )
else:
self._heartbeat_thread = self._user_passed_heartbeat_thread
@@ -323,15 +396,22 @@ def _stop_heartbeat_thread(self):
self._heartbeat_thread = None
def get_row_html(self, current_session_id):
-        return u"""<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td></tr>""".format(
-            self.id, self.get_app_id(), self.kind, self.status,
-            self.get_html_link(u'Link', self.get_spark_ui_url()), self.get_html_link(u'Link', self.get_driver_log_url()),
-            self.get_user(), u"" if current_session_id is None or current_session_id != self.id else u"\u2714"
+        return """<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td></tr>""".format(
+ self.id,
+ self.get_app_id(),
+ self.kind,
+ self.status,
+ self.get_html_link("Link", self.get_spark_ui_url()),
+ self.get_html_link("Link", self.get_driver_log_url()),
+ self.get_user(),
+ ""
+ if current_session_id is None or current_session_id != self.id
+ else "\u2714",
)
@staticmethod
def get_html_link(text, url):
if url is not None:
-            return u"""<a target="_blank" href="{1}">{0}</a>""".format(text, url)
+            return """<a target="_blank" href="{1}">{0}</a>""".format(text, url)
else:
- return u""
+ return ""
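
The wait_for_idle loop above, compressed for reference; the status names and the sleep schedule are stand-ins for the constants.* values and ConfigurableRetryPolicy used by the real code:

    import time

    def wait_for_idle_sketch(session, seconds_to_wait=60.0):
        retries = 1
        while True:
            session.refresh_status_and_info()
            if session.status == "idle":
                return
            if session.status in ("dead", "error", "success"):
                raise RuntimeError("Session reached a final status. See logs.")
            if seconds_to_wait <= 0.0:
                raise TimeoutError("Session did not reach idle status in time.")
            start = time.time()
            time.sleep(min(2.0, 0.5 * retries))  # stand-in for the policy
            retries += 1
            seconds_to_wait -= time.time() - start
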
diff --git a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py
index 75786347b..29f3906da 100644
--- a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py
+++ b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py
@@ -7,6 +7,7 @@
from sparkmagic.utils.sparklogger import SparkLog
from .exceptions import HttpClientException
+
class ReliableHttpClient(object):
"""Http client that is reliable in its requests. Uses requests library."""
@@ -16,49 +17,71 @@ def __init__(self, endpoint, headers, retry_policy):
self._retry_policy = retry_policy
self._auth = self._endpoint.auth
self._session = requests.Session()
- self.logger = SparkLog(u"ReliableHttpClient")
+ self.logger = SparkLog("ReliableHttpClient")
self.verify_ssl = not conf.ignore_ssl_errors()
if not self.verify_ssl:
- self.logger.debug(u"ATTENTION: Will ignore SSL errors. This might render you vulnerable to attacks.")
+ self.logger.debug(
+ "ATTENTION: Will ignore SSL errors. This might render you vulnerable to attacks."
+ )
requests.packages.urllib3.disable_warnings()
def get_headers(self):
return self._headers
def compose_url(self, relative_url):
- r_u = "/{}".format(relative_url.rstrip(u"/").lstrip(u"/"))
+ r_u = "/{}".format(relative_url.rstrip("/").lstrip("/"))
return self._endpoint.url + r_u
def get(self, relative_url, accepted_status_codes):
"""Sends a get request. Returns a response."""
- return self._send_request(relative_url, accepted_status_codes, self._session.get)
+ return self._send_request(
+ relative_url, accepted_status_codes, self._session.get
+ )
def post(self, relative_url, accepted_status_codes, data):
"""Sends a post request. Returns a response."""
- return self._send_request(relative_url, accepted_status_codes, self._session.post, data)
+ return self._send_request(
+ relative_url, accepted_status_codes, self._session.post, data
+ )
def delete(self, relative_url, accepted_status_codes):
"""Sends a delete request. Returns a response."""
- return self._send_request(relative_url, accepted_status_codes, self._session.delete)
+ return self._send_request(
+ relative_url, accepted_status_codes, self._session.delete
+ )
def _send_request(self, relative_url, accepted_status_codes, function, data=None):
- return self._send_request_helper(self.compose_url(relative_url), accepted_status_codes, function, data, 0)
+ return self._send_request_helper(
+ self.compose_url(relative_url), accepted_status_codes, function, data, 0
+ )
- def _send_request_helper(self, url, accepted_status_codes, function, data, retry_count):
+ def _send_request_helper(
+ self, url, accepted_status_codes, function, data, retry_count
+ ):
while True:
try:
if data is None:
- r = function(url, headers=self._headers, auth=self._auth, verify=self.verify_ssl)
+ r = function(
+ url,
+ headers=self._headers,
+ auth=self._auth,
+ verify=self.verify_ssl,
+ )
else:
- r = function(url, headers=self._headers, auth=self._auth,
- data=json.dumps(data), verify=self.verify_ssl)
+ r = function(
+ url,
+ headers=self._headers,
+ auth=self._auth,
+ data=json.dumps(data),
+ verify=self.verify_ssl,
+ )
except requests.exceptions.RequestException as e:
error = True
r = None
status = None
text = None
- self.logger.error(u"Request to '{}' failed with '{}'".format(url, e))
+ self.logger.error("Request to '{}' failed with '{}'".format(url, e))
else:
error = False
status = r.status_code
@@ -71,8 +94,13 @@ def _send_request_helper(self, url, accepted_status_codes, function, data, retry
continue
if error:
- raise HttpClientException(u"Error sending http request and maximum retry encountered.")
+ raise HttpClientException(
+ "Error sending http request and maximum retry encountered."
+ )
else:
- raise HttpClientException(u"Invalid status code '{}' from {} with error payload: {}"
- .format(status, url, text))
+ raise HttpClientException(
+ "Invalid status code '{}' from {} with error payload: {}".format(
+ status, url, text
+ )
+ )
return r
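
A usage sketch for the client above, with endpoint and retry_policy built as in LivyReliableHttpClient.from_endpoint:

    client = ReliableHttpClient(
        endpoint, {"Content-Type": "application/json"}, retry_policy
    )
    # Retries per the policy on errors or unexpected status codes, then
    # raises HttpClientException; otherwise returns the requests.Response.
    sessions = client.get("/sessions", [200]).json()["sessions"]
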
diff --git a/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py
index b145780bf..3e7f26c04 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py
@@ -9,10 +9,11 @@
import pandas as pd
+
class SendPandasDfToSparkCommand(SendToSparkCommand):
    # convert unicode to utf8 or pyspark will mark data as corrupted (and deserialize incorrectly)
- _python_decode = u'''
+ _python_decode = """
import sys
import json
@@ -37,19 +38,25 @@ def _byteify(data, ignore_dicts = False):
for key, value in data.iteritems()
}
return data
- '''
-
- def __init__(self, input_variable_name, input_variable_value, output_variable_name, max_rows):
- super(SendPandasDfToSparkCommand, self).__init__(input_variable_name, input_variable_value, output_variable_name)
+ """
+
+ def __init__(
+ self, input_variable_name, input_variable_value, output_variable_name, max_rows
+ ):
+ super(SendPandasDfToSparkCommand, self).__init__(
+ input_variable_name, input_variable_value, output_variable_name
+ )
self.max_rows = max_rows
def _scala_command(self, input_variable_name, pandas_df, output_variable_name):
self._assert_input_is_pandas_dataframe(input_variable_name, pandas_df)
pandas_json = self._get_dataframe_as_json(pandas_df)
- scala_code = u'''
+ scala_code = '''
val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil)
- val {} = spark.read.json(rdd_json_array)'''.format(pandas_json, output_variable_name)
+ val {} = spark.read.json(rdd_json_array)'''.format(
+ pandas_json, output_variable_name
+ )
return Command(scala_code)
@@ -60,10 +67,12 @@ def _pyspark_command(self, input_variable_name, pandas_df, output_variable_name)
pandas_json = self._get_dataframe_as_json(pandas_df)
- pyspark_code += u'''
+ pyspark_code += """
json_array = json_loads_byteified('{}')
rdd_json_array = spark.sparkContext.parallelize(json_array)
- {} = spark.read.json(rdd_json_array)'''.format(pandas_json, output_variable_name)
+ {} = spark.read.json(rdd_json_array)""".format(
+ pandas_json, output_variable_name
+ )
return Command(pyspark_code)
@@ -71,20 +80,28 @@ def _r_command(self, input_variable_name, pandas_df, output_variable_name):
self._assert_input_is_pandas_dataframe(input_variable_name, pandas_df)
pandas_json = self._get_dataframe_as_json(pandas_df)
- r_code = u'''
+ r_code = """
fileConn<-file("temporary_pandas_df_sparkmagics.txt")
writeLines('{}', fileConn)
close(fileConn)
{} <- read.json("temporary_pandas_df_sparkmagics.txt")
{}.persist()
- file.remove("temporary_pandas_df_sparkmagics.txt")'''.format(pandas_json, output_variable_name, output_variable_name)
+ file.remove("temporary_pandas_df_sparkmagics.txt")""".format(
+ pandas_json, output_variable_name, output_variable_name
+ )
return Command(r_code)
def _get_dataframe_as_json(self, pandas_df):
- return pandas_df.head(self.max_rows).to_json(orient=u'records')
+ return pandas_df.head(self.max_rows).to_json(orient="records")
- def _assert_input_is_pandas_dataframe(self, input_variable_name, input_variable_value):
+ def _assert_input_is_pandas_dataframe(
+ self, input_variable_name, input_variable_value
+ ):
if not isinstance(input_variable_value, pd.DataFrame):
wrong_type = input_variable_value.__class__.__name__
- raise BadUserDataException(u'{} is not a Pandas DataFrame! Got {} instead.'.format(input_variable_name, wrong_type))
+ raise BadUserDataException(
+ "{} is not a Pandas DataFrame! Got {} instead.".format(
+ input_variable_name, wrong_type
+ )
+ )
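
End to end, this command serializes up to max_rows rows of a local DataFrame to JSON and replays them as a Spark DataFrame inside the Livy session. A usage sketch (session hypothetical):

    import pandas as pd
    from sparkmagic.livyclientlib.sendpandasdftosparkcommand import (
        SendPandasDfToSparkCommand,
    )

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
    command = SendPandasDfToSparkCommand("df", df, "df_on_spark", 500)
    # command.execute(session) would materialize df_on_spark remotely.
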
diff --git a/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py
index b2b861a83..29ea3dee8 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py
@@ -5,25 +5,43 @@
from sparkmagic.livyclientlib.command import Command
from sparkmagic.livyclientlib.exceptions import BadUserDataException
-class SendStringToSparkCommand(SendToSparkCommand):
- def _scala_command(self, input_variable_name, input_variable_value, output_variable_name):
+class SendStringToSparkCommand(SendToSparkCommand):
+ def _scala_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
self._assert_input_is_string_type(input_variable_name, input_variable_value)
- scala_code = u'var {} = """{}"""'.format(output_variable_name, input_variable_value)
+ scala_code = 'var {} = """{}"""'.format(
+ output_variable_name, input_variable_value
+ )
return Command(scala_code)
- def _pyspark_command(self, input_variable_name, input_variable_value, output_variable_name):
+ def _pyspark_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
self._assert_input_is_string_type(input_variable_name, input_variable_value)
- pyspark_code = u'{} = {}'.format(output_variable_name, repr(input_variable_value))
+ pyspark_code = "{} = {}".format(
+ output_variable_name, repr(input_variable_value)
+ )
return Command(pyspark_code)
- def _r_command(self, input_variable_name, input_variable_value, output_variable_name):
+ def _r_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
self._assert_input_is_string_type(input_variable_name, input_variable_value)
- escaped_input_variable_value = input_variable_value.replace(u'\\', u'\\\\').replace(u'"',u'\\"')
- r_code = u'''assign("{}","{}")'''.format(output_variable_name, escaped_input_variable_value)
+ escaped_input_variable_value = input_variable_value.replace(
+ "\\", "\\\\"
+ ).replace('"', '\\"')
+ r_code = """assign("{}","{}")""".format(
+ output_variable_name, escaped_input_variable_value
+ )
return Command(r_code)
def _assert_input_is_string_type(self, input_variable_name, input_variable_value):
if not isinstance(input_variable_value, str):
wrong_type = input_variable_value.__class__.__name__
- raise BadUserDataException(u'{} is not a str or bytes! Got {} instead'.format(input_variable_name, wrong_type))
+ raise BadUserDataException(
+ "{} is not a str or bytes! Got {} instead".format(
+ input_variable_name, wrong_type
+ )
+ )
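
The R escaping above, worked through on a concrete value: backslashes are doubled first so that the quote-escaping backslashes added afterwards are not themselves doubled again.

    value = 'He said "hi" \\ bye'
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    # printed, escaped reads: He said \"hi\" \\ bye
    r_code = '''assign("out","{}")'''.format(escaped)
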
diff --git a/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py
index 5aa45f5a5..cc2cb729c 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py
@@ -7,8 +7,15 @@
from abc import abstractmethod
+
class SendToSparkCommand(Command):
- def __init__(self, input_variable_name, input_variable_value, output_variable_name, spark_events=None):
+ def __init__(
+ self,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ spark_events=None,
+ ):
super(SendToSparkCommand, self).__init__("", spark_events)
self.input_variable_name = input_variable_name
self.input_variable_value = input_variable_value
@@ -16,29 +23,48 @@ def __init__(self, input_variable_name, input_variable_value, output_variable_na
def execute(self, session):
try:
- command = self.to_command(session.kind, self.input_variable_name, self.input_variable_value, self.output_variable_name)
+ command = self.to_command(
+ session.kind,
+ self.input_variable_name,
+ self.input_variable_value,
+ self.output_variable_name,
+ )
return command.execute(session)
except Exception as e:
raise e
- def to_command(self, kind, input_variable_name, input_variable_value, output_variable_name):
+ def to_command(
+ self, kind, input_variable_name, input_variable_value, output_variable_name
+ ):
if kind == constants.SESSION_KIND_PYSPARK:
- return self._pyspark_command(input_variable_name, input_variable_value, output_variable_name)
+ return self._pyspark_command(
+ input_variable_name, input_variable_value, output_variable_name
+ )
elif kind == constants.SESSION_KIND_SPARK:
- return self._scala_command(input_variable_name, input_variable_value, output_variable_name)
+ return self._scala_command(
+ input_variable_name, input_variable_value, output_variable_name
+ )
elif kind == constants.SESSION_KIND_SPARKR:
- return self._r_command(input_variable_name, input_variable_value, output_variable_name)
+ return self._r_command(
+ input_variable_name, input_variable_value, output_variable_name
+ )
else:
- raise BadUserDataException(u"Kind '{}' is not supported.".format(kind))
+ raise BadUserDataException("Kind '{}' is not supported.".format(kind))
@abstractmethod
- def _scala_command(self, input_variable_name, input_variable_value, output_variable_name):
- raise NotImplementedError #override and provide proper implementation in supertype!
+ def _scala_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
+        raise NotImplementedError  # override and provide proper implementation in subtype!
@abstractmethod
- def _pyspark_command(self, input_variable_name, input_variable_value, output_variable_name):
- raise NotImplementedError #override and provide proper implementation in supertype!
+ def _pyspark_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
+        raise NotImplementedError  # override and provide proper implementation in subtype!
@abstractmethod
- def _r_command(self, input_variable_name, input_variable_value, output_variable_name):
- raise NotImplementedError #override and provide proper implementation in supertype!
+ def _r_command(
+ self, input_variable_name, input_variable_value, output_variable_name
+ ):
+        raise NotImplementedError  # override and provide proper implementation in subtype!
diff --git a/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py b/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py
index 9ba850dd0..2004587c4 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py
@@ -9,7 +9,7 @@
class SessionManager(object):
def __init__(self, ipython_display):
- self.logger = SparkLog(u"SessionManager")
+ self.logger = SparkLog("SessionManager")
self.ipython_display = ipython_display
self._sessions = dict()
@@ -24,12 +24,17 @@ def get_sessions_list(self):
return list(self._sessions.keys())
def get_sessions_info(self):
- return [u"Name: {}\t{}".format(k, str(self._sessions[k])) for k in list(self._sessions.keys())]
+ return [
+ "Name: {}\t{}".format(k, str(self._sessions[k]))
+ for k in list(self._sessions.keys())
+ ]
def add_session(self, name, session):
if name in self._sessions:
- raise SessionManagementException(u"Session with name '{}' already exists. Please delete the session"
- u" first if you intend to replace it.".format(name))
+ raise SessionManagementException(
+ "Session with name '{}' already exists. Please delete the session"
+ " first if you intend to replace it.".format(name)
+ )
self._sessions[name] = session
@@ -39,16 +44,24 @@ def get_any_session(self):
key = self.get_sessions_list()[0]
return self._sessions[key]
elif number_of_sessions == 0:
- raise SessionManagementException(u"You need to have at least 1 client created to execute commands.")
+ raise SessionManagementException(
+ "You need to have at least 1 client created to execute commands."
+ )
else:
- raise SessionManagementException(u"Please specify the client to use. Possible sessions are {}".format(
- self.get_sessions_list()))
-
+ raise SessionManagementException(
+ "Please specify the client to use. Possible sessions are {}".format(
+ self.get_sessions_list()
+ )
+ )
+
def get_session(self, name):
if name in self._sessions:
return self._sessions[name]
- raise SessionManagementException(u"Could not find '{}' session in list of saved sessions. Possible sessions are {}".format(
- name, self.get_sessions_list()))
+ raise SessionManagementException(
+ "Could not find '{}' session in list of saved sessions. Possible sessions are {}".format(
+ name, self.get_sessions_list()
+ )
+ )
def get_session_id_for_client(self, name):
if name in self.get_sessions_list():
@@ -63,7 +76,7 @@ def get_session_name_by_id_endpoint(self, id, endpoint):
def delete_client(self, name):
self._remove_session(name)
-
+
def clean_up_all(self):
for name in self.get_sessions_list():
self._remove_session(name)
@@ -73,19 +86,26 @@ def _remove_session(self, name):
self._sessions[name].delete()
del self._sessions[name]
else:
- raise SessionManagementException(u"Could not find '{}' session in list of saved sessions. Possible sessions are {}"
- .format(name, self.get_sessions_list()))
+ raise SessionManagementException(
+ "Could not find '{}' session in list of saved sessions. Possible sessions are {}".format(
+ name, self.get_sessions_list()
+ )
+ )
def _register_cleanup_on_exit(self):
"""
Stop the livy sessions before python process exits for any reason (if enabled in conf)
"""
if conf.cleanup_all_sessions_on_exit():
+
def cleanup_spark_sessions():
try:
self.clean_up_all()
except Exception as e:
- self.logger.error(u"Error cleaning up sessions on exit: {}".format(e))
+ self.logger.error(
+ "Error cleaning up sessions on exit: {}".format(e)
+ )
pass
+
atexit.register(cleanup_spark_sessions)
- self.ipython_display.writeln(u"Cleaning up livy sessions on exit is enabled")
+ self.ipython_display.writeln("Cleaning up livy sessions on exit is enabled")
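
The cleanup-on-exit hook above, reduced to the bare pattern (names hypothetical):

    import atexit

    def cleanup_spark_sessions(manager):
        try:
            manager.clean_up_all()
        except Exception as e:
            print("Error cleaning up sessions on exit: {}".format(e))

    # Registered once at startup; runs when the Python process exits.
    atexit.register(cleanup_spark_sessions, manager)
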
diff --git a/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py b/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py
index da8ec6298..419ee08f0 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py
@@ -10,9 +10,8 @@
class SparkController(object):
-
def __init__(self, ipython_display):
- self.logger = SparkLog(u"SparkController")
+ self.logger = SparkLog("SparkController")
self.ipython_display = ipython_display
self.session_manager = SessionManager(ipython_display)
# this is to reuse the already created http clients
@@ -45,11 +44,22 @@ def run_sqlquery(self, sqlquery, client_name=None):
def get_all_sessions_endpoint(self, endpoint):
http_client = self._http_client(endpoint)
- sessions = http_client.get_sessions()[u"sessions"]
- supported_sessions = filter(lambda s: (s[constants.LIVY_KIND_PARAM] in constants.SESSION_KINDS_SUPPORTED), sessions)
- session_list = [self._livy_session(http_client, {constants.LIVY_KIND_PARAM: s[constants.LIVY_KIND_PARAM]},
- self.ipython_display, s[u"id"])
- for s in supported_sessions]
+ sessions = http_client.get_sessions()["sessions"]
+ supported_sessions = filter(
+ lambda s: (
+ s[constants.LIVY_KIND_PARAM] in constants.SESSION_KINDS_SUPPORTED
+ ),
+ sessions,
+ )
+ session_list = [
+ self._livy_session(
+ http_client,
+ {constants.LIVY_KIND_PARAM: s[constants.LIVY_KIND_PARAM]},
+ self.ipython_display,
+ s["id"],
+ )
+ for s in supported_sessions
+ ]
for s in session_list:
s.refresh_status_and_info()
return session_list
@@ -69,7 +79,9 @@ def delete_session_by_name(self, name):
self.session_manager.delete_client(name)
def delete_session_by_id(self, endpoint, session_id):
- name = self.session_manager.get_session_name_by_id_endpoint(session_id, endpoint)
+ name = self.session_manager.get_session_name_by_id_endpoint(
+ session_id, endpoint
+ )
if name in self.session_manager.get_sessions_list():
self.delete_session_by_name(name)
@@ -77,13 +89,21 @@ def delete_session_by_id(self, endpoint, session_id):
http_client = self._http_client(endpoint)
response = http_client.get_session(session_id)
http_client = self._http_client(endpoint)
- session = self._livy_session(http_client, {constants.LIVY_KIND_PARAM: response[constants.LIVY_KIND_PARAM]},
- self.ipython_display, session_id)
+ session = self._livy_session(
+ http_client,
+ {constants.LIVY_KIND_PARAM: response[constants.LIVY_KIND_PARAM]},
+ self.ipython_display,
+ session_id,
+ )
session.delete()
def add_session(self, name, endpoint, skip_if_exists, properties):
if skip_if_exists and (name in self.session_manager.get_sessions_list()):
- self.logger.debug(u"Skipping {} because it already exists in list of sessions.".format(name))
+ self.logger.debug(
+ "Skipping {} because it already exists in list of sessions.".format(
+ name
+ )
+ )
return
http_client = self._http_client(endpoint)
session = self._livy_session(http_client, properties, self.ipython_display)
@@ -97,7 +117,6 @@ def add_session(self, name, endpoint, skip_if_exists, properties):
else:
self.session_manager.add_session(name, session)
-
def get_session_id_for_client(self, name):
return self.session_manager.get_session_id_for_client(name)
@@ -117,12 +136,18 @@ def get_managed_clients(self):
return self.session_manager.sessions
@staticmethod
- def _livy_session(http_client, properties, ipython_display,
- session_id=-1):
- return LivySession(http_client, properties, ipython_display,
- session_id, heartbeat_timeout=conf.livy_server_heartbeat_timeout_seconds())
+ def _livy_session(http_client, properties, ipython_display, session_id=-1):
+ return LivySession(
+ http_client,
+ properties,
+ ipython_display,
+ session_id,
+ heartbeat_timeout=conf.livy_server_heartbeat_timeout_seconds(),
+ )
def _http_client(self, endpoint):
if endpoint not in self._http_clients:
- self._http_clients[endpoint] = LivyReliableHttpClient.from_endpoint(endpoint)
+ self._http_clients[endpoint] = LivyReliableHttpClient.from_endpoint(
+ endpoint
+ )
return self._http_clients[endpoint]
diff --git a/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py b/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py
index 2b5919fc3..6bcc836d0 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py
@@ -5,13 +5,25 @@
import sparkmagic.utils.configuration as conf
from sparkmagic.utils.sparkevents import SparkEvents
from sparkmagic.livyclientlib.command import Command
-from sparkmagic.livyclientlib.exceptions import DataFrameParseException, BadUserDataException
+from sparkmagic.livyclientlib.exceptions import (
+ DataFrameParseException,
+ BadUserDataException,
+)
import sparkmagic.utils.constants as constants
import ast
+
class SparkStoreCommand(Command):
- def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None):
+ def __init__(
+ self,
+ output_var,
+ samplemethod=None,
+ maxrows=None,
+ samplefraction=None,
+ spark_events=None,
+ coerce=None,
+ ):
super(SparkStoreCommand, self).__init__("", spark_events)
if samplemethod is None:
@@ -21,12 +33,16 @@ def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=N
if samplefraction is None:
samplefraction = conf.default_samplefraction()
- if samplemethod not in {u'take', u'sample'}:
- raise BadUserDataException(u'samplemethod (-m) must be one of (take, sample)')
+ if samplemethod not in {"take", "sample"}:
+ raise BadUserDataException(
+ "samplemethod (-m) must be one of (take, sample)"
+ )
if not isinstance(maxrows, int):
- raise BadUserDataException(u'maxrows (-n) must be an integer')
+ raise BadUserDataException("maxrows (-n) must be an integer")
if not 0.0 <= samplefraction <= 1.0:
- raise BadUserDataException(u'samplefraction (-r) must be a float between 0.0 and 1.0')
+ raise BadUserDataException(
+ "samplefraction (-r) must be a float between 0.0 and 1.0"
+ )
self.samplemethod = samplemethod
self.maxrows = maxrows
@@ -37,7 +53,6 @@ def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=N
self._spark_events = spark_events
self._coerce = coerce
-
def execute(self, session):
try:
command = self.to_command(session.kind, self.output_var)
@@ -50,7 +65,6 @@ def execute(self, session):
else:
return result
-
def to_command(self, kind, spark_context_variable_name):
if kind == constants.SESSION_KIND_PYSPARK:
return self._pyspark_command(spark_context_variable_name)
@@ -59,64 +73,63 @@ def to_command(self, kind, spark_context_variable_name):
elif kind == constants.SESSION_KIND_SPARKR:
return self._r_command(spark_context_variable_name)
else:
- raise BadUserDataException(u"Kind '{}' is not supported.".format(kind))
-
+ raise BadUserDataException("Kind '{}' is not supported.".format(kind))
def _pyspark_command(self, spark_context_variable_name):
# use_unicode=False means the result will be UTF-8-encoded bytes, so we
# set it to False for Python 2.
- command = u'{}.toJSON(use_unicode=(sys.version_info.major > 2))'.format(
- spark_context_variable_name)
- if self.samplemethod == u'sample':
- command = u'{}.sample(False, {})'.format(command, self.samplefraction)
+ command = "{}.toJSON(use_unicode=(sys.version_info.major > 2))".format(
+ spark_context_variable_name
+ )
+ if self.samplemethod == "sample":
+ command = "{}.sample(False, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'{}.take({})'.format(command, self.maxrows)
+ command = "{}.take({})".format(command, self.maxrows)
else:
- command = u'{}.collect()'.format(command)
+ command = "{}.collect()".format(command)
# Unicode support has improved in Python 3 so we don't need to encode.
print_command = constants.LONG_RANDOM_VARIABLE_NAME
- command = u'import sys\nfor {} in {}: print({})'.format(
- constants.LONG_RANDOM_VARIABLE_NAME,
- command,
- print_command)
+ command = "import sys\nfor {} in {}: print({})".format(
+ constants.LONG_RANDOM_VARIABLE_NAME, command, print_command
+ )
return Command(command)
-
def _scala_command(self, spark_context_variable_name):
- command = u'{}.toJSON'.format(spark_context_variable_name)
- if self.samplemethod == u'sample':
- command = u'{}.sample(false, {})'.format(command, self.samplefraction)
+ command = "{}.toJSON".format(spark_context_variable_name)
+ if self.samplemethod == "sample":
+ command = "{}.sample(false, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'{}.take({})'.format(command, self.maxrows)
+ command = "{}.take({})".format(command, self.maxrows)
else:
- command = u'{}.collect'.format(command)
- return Command(u'{}.foreach(println)'.format(command))
-
+ command = "{}.collect".format(command)
+ return Command("{}.foreach(println)".format(command))
def _r_command(self, spark_context_variable_name):
command = spark_context_variable_name
- if self.samplemethod == u'sample':
- command = u'sample({}, FALSE, {})'.format(command,
- self.samplefraction)
+ if self.samplemethod == "sample":
+ command = "sample({}, FALSE, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'take({},{})'.format(command, self.maxrows)
+ command = "take({},{})".format(command, self.maxrows)
else:
- command = u'collect({})'.format(command)
- command = u'jsonlite::toJSON({})'.format(command)
- command = u'for ({} in ({})) {{cat({})}}'.format(constants.LONG_RANDOM_VARIABLE_NAME,
- command,
- constants.LONG_RANDOM_VARIABLE_NAME)
+ command = "collect({})".format(command)
+ command = "jsonlite::toJSON({})".format(command)
+ command = "for ({} in ({})) {{cat({})}}".format(
+ constants.LONG_RANDOM_VARIABLE_NAME,
+ command,
+ constants.LONG_RANDOM_VARIABLE_NAME,
+ )
return Command(command)
-
# Used only for unit testing
def __eq__(self, other):
- return self.code == other.code and \
- self.samplemethod == other.samplemethod and \
- self.maxrows == other.maxrows and \
- self.samplefraction == other.samplefraction and \
- self.output_var == other.output_var and \
- self._coerce == other._coerce
+ return (
+ self.code == other.code
+ and self.samplemethod == other.samplemethod
+ and self.maxrows == other.maxrows
+ and self.samplefraction == other.samplefraction
+ and self.output_var == other.output_var
+ and self._coerce == other._coerce
+ )
def __ne__(self, other):
return not (self == other)
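
For context on the command-builder logic above, a minimal standalone sketch of the PySpark snippet `_pyspark_command` assembles, assuming `output_var="df"`, `samplemethod="sample"`, `samplefraction=0.1`, and `maxrows=100`; `VAR` here is a hypothetical stand-in for `constants.LONG_RANDOM_VARIABLE_NAME`:

    VAR = "_sm_row"  # hypothetical name; the real constant is a long random identifier

    command = "df.toJSON(use_unicode=(sys.version_info.major > 2))"
    command = "{}.sample(False, {})".format(command, 0.1)  # samplemethod == "sample"
    command = "{}.take({})".format(command, 100)           # maxrows >= 0, else .collect()
    command = "import sys\nfor {} in {}: print({})".format(VAR, command, VAR)
    print(command)
    # import sys
    # for _sm_row in df.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.1).take(100): print(_sm_row)
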
diff --git a/sparkmagic/sparkmagic/livyclientlib/sqlquery.py b/sparkmagic/sparkmagic/livyclientlib/sqlquery.py
index 6f460810b..3e821c85c 100644
--- a/sparkmagic/sparkmagic/livyclientlib/sqlquery.py
+++ b/sparkmagic/sparkmagic/livyclientlib/sqlquery.py
@@ -1,6 +1,9 @@
from hdijupyterutils.guid import ObjectWithGuid
-from sparkmagic.utils.utils import coerce_pandas_df_to_numeric_datetime, records_to_dataframe
+from sparkmagic.utils.utils import (
+ coerce_pandas_df_to_numeric_datetime,
+ records_to_dataframe,
+)
import sparkmagic.utils.configuration as conf
import sparkmagic.utils.constants as constants
from sparkmagic.utils.sparkevents import SparkEvents
@@ -9,9 +12,17 @@
class SQLQuery(ObjectWithGuid):
- def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None):
+ def __init__(
+ self,
+ query,
+ samplemethod=None,
+ maxrows=None,
+ samplefraction=None,
+ spark_events=None,
+ coerce=None,
+ ):
super(SQLQuery, self).__init__()
-
+
if samplemethod is None:
samplemethod = conf.default_samplemethod()
if maxrows is None:
@@ -19,12 +30,16 @@ def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None,
if samplefraction is None:
samplefraction = conf.default_samplefraction()
- if samplemethod not in {u'take', u'sample'}:
- raise BadUserDataException(u'samplemethod (-m) must be one of (take, sample)')
+ if samplemethod not in {"take", "sample"}:
+ raise BadUserDataException(
+ "samplemethod (-m) must be one of (take, sample)"
+ )
if not isinstance(maxrows, int):
- raise BadUserDataException(u'maxrows (-n) must be an integer')
+ raise BadUserDataException("maxrows (-n) must be an integer")
if not 0.0 <= samplefraction <= 1.0:
- raise BadUserDataException(u'samplefraction (-r) must be a float between 0.0 and 1.0')
+ raise BadUserDataException(
+ "samplefraction (-r) must be a float between 0.0 and 1.0"
+ )
self.query = query
self.samplemethod = samplemethod
@@ -43,12 +58,19 @@ def to_command(self, kind, sql_context_variable_name):
elif kind == constants.SESSION_KIND_SPARKR:
return self._r_command(sql_context_variable_name)
else:
- raise BadUserDataException(u"Kind '{}' is not supported.".format(kind))
+ raise BadUserDataException("Kind '{}' is not supported.".format(kind))
def execute(self, session):
- self._spark_events.emit_sql_execution_start_event(session.guid, session.kind, session.id, self.guid,
- self.samplemethod, self.maxrows, self.samplefraction)
- command_guid = ''
+ self._spark_events.emit_sql_execution_start_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ self.samplemethod,
+ self.maxrows,
+ self.samplefraction,
+ )
+ command_guid = ""
try:
command = self.to_command(session.kind, session.sql_context_variable_name)
command_guid = command.guid
@@ -57,66 +79,89 @@ def execute(self, session):
raise BadUserDataException(records_text)
result = records_to_dataframe(records_text, session.kind, self._coerce)
except Exception as e:
- self._spark_events.emit_sql_execution_end_event(session.guid, session.kind, session.id, self.guid,
- command_guid, False, e.__class__.__name__, str(e))
+ self._spark_events.emit_sql_execution_end_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ command_guid,
+ False,
+ e.__class__.__name__,
+ str(e),
+ )
raise
else:
- self._spark_events.emit_sql_execution_end_event(session.guid, session.kind, session.id, self.guid,
- command_guid, True, "", "")
+ self._spark_events.emit_sql_execution_end_event(
+ session.guid,
+ session.kind,
+ session.id,
+ self.guid,
+ command_guid,
+ True,
+ "",
+ "",
+ )
return result
-
def _pyspark_command(self, sql_context_variable_name):
# use_unicode=False means the result will be UTF-8-encoded bytes, so we
# set it to False for Python 2.
- command = u'{}.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2))'.format(
- sql_context_variable_name, self.query)
- if self.samplemethod == u'sample':
- command = u'{}.sample(False, {})'.format(command, self.samplefraction)
+ command = '{}.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2))'.format(
+ sql_context_variable_name, self.query
+ )
+ if self.samplemethod == "sample":
+ command = "{}.sample(False, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'{}.take({})'.format(command, self.maxrows)
+ command = "{}.take({})".format(command, self.maxrows)
else:
- command = u'{}.collect()'.format(command)
+ command = "{}.collect()".format(command)
print_command = constants.LONG_RANDOM_VARIABLE_NAME
- command = u'import sys\nfor {} in {}: print({})'.format(
- constants.LONG_RANDOM_VARIABLE_NAME,
- command,
- print_command)
+ command = "import sys\nfor {} in {}: print({})".format(
+ constants.LONG_RANDOM_VARIABLE_NAME, command, print_command
+ )
return Command(command)
def _scala_command(self, sql_context_variable_name):
- command = u'{}.sql("""{}""").toJSON'.format(sql_context_variable_name, self.query)
- if self.samplemethod == u'sample':
- command = u'{}.sample(false, {})'.format(command, self.samplefraction)
+ command = '{}.sql("""{}""").toJSON'.format(
+ sql_context_variable_name, self.query
+ )
+ if self.samplemethod == "sample":
+ command = "{}.sample(false, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'{}.take({})'.format(command, self.maxrows)
+ command = "{}.take({})".format(command, self.maxrows)
else:
- command = u'{}.collect'.format(command)
- return Command(u'{}.foreach(println)'.format(command))
+ command = "{}.collect".format(command)
+ return Command("{}.foreach(println)".format(command))
def _r_command(self, sql_context_variable_name):
- if sql_context_variable_name == 'spark':
- command = u'sql("{}")'.format(self.query)
+ if sql_context_variable_name == "spark":
+ command = 'sql("{}")'.format(self.query)
else:
- command = u'sql({}, "{}")'.format(sql_context_variable_name, self.query)
- if self.samplemethod == u'sample':
- command = u'sample({}, FALSE, {})'.format(command, self.samplefraction)
+ command = 'sql({}, "{}")'.format(sql_context_variable_name, self.query)
+ if self.samplemethod == "sample":
+ command = "sample({}, FALSE, {})".format(command, self.samplefraction)
if self.maxrows >= 0:
- command = u'take({},{})'.format(command, self.maxrows)
+ command = "take({},{})".format(command, self.maxrows)
else:
- command = u'collect({})'.format(command)
- command = u'jsonlite:::toJSON({})'.format(command)
- command = u'for ({} in ({})) {{cat({})}}'.format(constants.LONG_RANDOM_VARIABLE_NAME, command, constants.LONG_RANDOM_VARIABLE_NAME)
+ command = "collect({})".format(command)
+ command = "jsonlite:::toJSON({})".format(command)
+ command = "for ({} in ({})) {{cat({})}}".format(
+ constants.LONG_RANDOM_VARIABLE_NAME,
+ command,
+ constants.LONG_RANDOM_VARIABLE_NAME,
+ )
return Command(command)
# Used only for unit testing
def __eq__(self, other):
- return self.query == other.query and \
- self.samplemethod == other.samplemethod and \
- self.maxrows == other.maxrows and \
- self.samplefraction == other.samplefraction and \
- self._coerce == other._coerce
+ return (
+ self.query == other.query
+ and self.samplemethod == other.samplemethod
+ and self.maxrows == other.maxrows
+ and self.samplefraction == other.samplefraction
+ and self._coerce == other._coerce
+ )
def __ne__(self, other):
return not (self == other)
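
The R branch mirrors this string assembly; a sketch of what `_r_command` produces for `query="SELECT 1"` against the `spark` entry point with `samplemethod="take"` and `maxrows=20` (again, `VAR` stands in for the random variable name):

    VAR = "_sm_row"  # hypothetical stand-in for constants.LONG_RANDOM_VARIABLE_NAME

    command = 'sql("{}")'.format("SELECT 1")       # no context prefix when the variable is 'spark'
    command = "take({},{})".format(command, 20)    # maxrows >= 0, else collect(...)
    command = "jsonlite:::toJSON({})".format(command)
    command = "for ({} in ({})) {{cat({})}}".format(VAR, command, VAR)
    print(command)
    # for (_sm_row in (jsonlite:::toJSON(take(sql("SELECT 1"),20)))) {cat(_sm_row)}
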
diff --git a/sparkmagic/sparkmagic/magics/remotesparkmagics.py b/sparkmagic/sparkmagic/magics/remotesparkmagics.py
index 8c6fffcb8..eb3a1b3b7 100644
--- a/sparkmagic/sparkmagic/magics/remotesparkmagics.py
+++ b/sparkmagic/sparkmagic/magics/remotesparkmagics.py
@@ -12,8 +12,18 @@
from hdijupyterutils.ipywidgetfactory import IpyWidgetFactory
import sparkmagic.utils.configuration as conf
-from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value, initialize_auth
-from sparkmagic.utils.constants import CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, LANG_PYTHON, LANG_R, LANG_SCALA
+from sparkmagic.utils.utils import (
+ parse_argstring_or_throw,
+ get_coerce_value,
+ initialize_auth,
+)
+from sparkmagic.utils.constants import (
+ CONTEXT_NAME_SPARK,
+ CONTEXT_NAME_SQL,
+ LANG_PYTHON,
+ LANG_R,
+ LANG_SCALA,
+)
from sparkmagic.controllerwidget.magicscontrollerwidget import MagicsControllerWidget
from sparkmagic.livyclientlib.endpoint import Endpoint
from sparkmagic.magics.sparkmagicsbase import SparkMagicBase
@@ -28,7 +38,9 @@ def __init__(self, shell, data=None, widget=None):
self.endpoints = {}
if widget is None:
- widget = MagicsControllerWidget(self.spark_controller, IpyWidgetFactory(), self.ipython_display)
+ widget = MagicsControllerWidget(
+ self.spark_controller, IpyWidgetFactory(), self.ipython_display
+ )
self.manage_widget = widget
@line_magic
@@ -38,80 +50,163 @@ def manage_spark(self, line, local_ns=None):
return self.manage_widget
@magic_arguments()
- @argument("-c", "--context", type=str, default=CONTEXT_NAME_SPARK,
- help="Context to use: '{}' for spark and '{}' for sql queries. "
- "Default is '{}'.".format(CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, CONTEXT_NAME_SPARK))
- @argument("-s", "--session", type=str, default=None, help="The name of the Livy session to use.")
- @argument("-o", "--output", type=str, default=None, help="If present, output when using SQL "
- "queries will be stored in this variable.")
- @argument("-q", "--quiet", type=bool, default=False, nargs="?", const=True, help="Do not display visualizations"
- " on SQL queries")
- @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for SQL queries: either take or sample")
- @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back "
- "from the server for SQL queries")
- @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from SQL queries")
+ @argument(
+ "-c",
+ "--context",
+ type=str,
+ default=CONTEXT_NAME_SPARK,
+ help="Context to use: '{}' for spark and '{}' for sql queries. "
+ "Default is '{}'.".format(
+ CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, CONTEXT_NAME_SPARK
+ ),
+ )
+ @argument(
+ "-s",
+ "--session",
+ type=str,
+ default=None,
+ help="The name of the Livy session to use.",
+ )
+ @argument(
+ "-o",
+ "--output",
+ type=str,
+ default=None,
+ help="If present, output when using SQL "
+ "queries will be stored in this variable.",
+ )
+ @argument(
+ "-q",
+ "--quiet",
+ type=bool,
+ default=False,
+ nargs="?",
+ const=True,
+ help="Do not display visualizations" " on SQL queries",
+ )
+ @argument(
+ "-m",
+ "--samplemethod",
+ type=str,
+ default=None,
+ help="Sample method for SQL queries: either take or sample",
+ )
+ @argument(
+ "-n",
+ "--maxrows",
+ type=int,
+ default=None,
+ help="Maximum number of rows that will be pulled back "
+ "from the server for SQL queries",
+ )
+ @argument(
+ "-r",
+ "--samplefraction",
+ type=float,
+ default=None,
+ help="Sample fraction for sampling from SQL queries",
+ )
@argument("-u", "--url", type=str, default=None, help="URL for Livy endpoint")
- @argument("-a", "--user", dest='user', type=str, default="", help="Username for HTTP access to Livy endpoint")
- @argument("-p", "--password", type=str, default="", help="Password for HTTP access to Livy endpoint")
- @argument("-t", "--auth", type=str, default=None, help="Auth type for HTTP access to Livy endpoint. [Kerberos, None, Basic]")
- @argument("-l", "--language", type=str, default=None,
- help="Language for Livy session; one of {}".format(', '.join([LANG_PYTHON, LANG_SCALA, LANG_R])))
+ @argument(
+ "-a",
+ "--user",
+ dest="user",
+ type=str,
+ default="",
+ help="Username for HTTP access to Livy endpoint",
+ )
+ @argument(
+ "-p",
+ "--password",
+ type=str,
+ default="",
+ help="Password for HTTP access to Livy endpoint",
+ )
+ @argument(
+ "-t",
+ "--auth",
+ type=str,
+ default=None,
+ help="Auth type for HTTP access to Livy endpoint. [Kerberos, None, Basic]",
+ )
+ @argument(
+ "-l",
+ "--language",
+ type=str,
+ default=None,
+ help="Language for Livy session; one of {}".format(
+ ", ".join([LANG_PYTHON, LANG_SCALA, LANG_R])
+ ),
+ )
@argument("command", type=str, default=[""], nargs="*", help="Commands to execute.")
- @argument("-k", "--skip", type=bool, default=False, nargs="?", const=True, help="Skip adding session if it already exists")
+ @argument(
+ "-k",
+ "--skip",
+ type=bool,
+ default=False,
+ nargs="?",
+ const=True,
+ help="Skip adding session if it already exists",
+ )
@argument("-i", "--id", type=int, default=None, help="Session ID")
- @argument("-e", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
- "of the dataframe or not (pass False)")
-
+ @argument(
+ "-e",
+ "--coerce",
+ type=str,
+ default=None,
+ help="Whether to automatically coerce the types (default, pass True if being explicit) "
+ "of the dataframe or not (pass False)",
+ )
@needs_local_scope
@line_cell_magic
@handle_expected_exceptions
def spark(self, line, cell="", local_ns=None):
"""Magic to execute spark remotely.
- This magic allows you to create a Livy Scala or Python session against a Livy endpoint. Every session can
- be used to execute either Spark code or SparkSQL code by executing against the SQL context in the session.
- When the SQL context is used, the result will be a Pandas dataframe of a sample of the results.
-
- If invoked with no subcommand, the cell will be executed against the specified session.
-
- Subcommands
- -----------
- info
- Display the available Livy sessions and other configurations for sessions.
- add
- Add a Livy session given a session name (-s), language (-l), and endpoint credentials.
- The -k argument, if present, will skip adding this session if it already exists.
- e.g. `%spark add -s test -l python -u https://sparkcluster.net/livy -t Kerberos -a u -p -k`
- config
- Override the livy session properties sent to Livy on session creation. All session creations will
- contain these config settings from then on.
- Expected value is a JSON key-value string to be sent as part of the Request Body for the POST /sessions
- endpoint in Livy.
- e.g. `%%spark config`
- `{"driverMemory":"1000M", "executorCores":4}`
- run
- Run Spark code against a session.
- e.g. `%%spark -s testsession` will execute the cell code against the testsession previously created
- e.g. `%%spark -s testsession -c sql` will execute the SQL code against the testsession previously created
- e.g. `%%spark -s testsession -c sql -o my_var` will execute the SQL code against the testsession
- previously created and store the pandas dataframe created in the my_var variable in the
- Python environment.
- logs
- Returns the logs for a given session.
- e.g. `%spark logs -s testsession` will return the logs for the testsession previously created
- delete
- Delete a Livy session.
- e.g. `%spark delete -s defaultlivy`
- cleanup
- Delete all Livy sessions created by the notebook. No arguments required.
- e.g. `%spark cleanup`
+ This magic allows you to create a Livy Scala or Python session against a Livy endpoint. Every session can
+ be used to execute either Spark code or SparkSQL code by executing against the SQL context in the session.
+ When the SQL context is used, the result will be a Pandas dataframe of a sample of the results.
+
+ If invoked with no subcommand, the cell will be executed against the specified session.
+
+ Subcommands
+ -----------
+ info
+ Display the available Livy sessions and other configurations for sessions.
+ add
+ Add a Livy session given a session name (-s), language (-l), and endpoint credentials.
+ The -k argument, if present, will skip adding this session if it already exists.
+ e.g. `%spark add -s test -l python -u https://sparkcluster.net/livy -t Kerberos -a u -p -k`
+ config
+ Override the livy session properties sent to Livy on session creation. All session creations will
+ contain these config settings from then on.
+ Expected value is a JSON key-value string to be sent as part of the Request Body for the POST /sessions
+ endpoint in Livy.
+ e.g. `%%spark config`
+ `{"driverMemory":"1000M", "executorCores":4}`
+ run
+ Run Spark code against a session.
+ e.g. `%%spark -s testsession` will execute the cell code against the testsession previously created
+ e.g. `%%spark -s testsession -c sql` will execute the SQL code against the testsession previously created
+ e.g. `%%spark -s testsession -c sql -o my_var` will execute the SQL code against the testsession
+ previously created and store the pandas dataframe created in the my_var variable in the
+ Python environment.
+ logs
+ Returns the logs for a given session.
+ e.g. `%spark logs -s testsession` will return the logs for the testsession previously created
+ delete
+ Delete a Livy session.
+ e.g. `%spark delete -s defaultlivy`
+ cleanup
+ Delete all Livy sessions created by the notebook. No arguments required.
+ e.g. `%spark cleanup`
"""
usage = "Please look at usage of %spark by executing `%spark?`."
user_input = line
args = parse_argstring_or_throw(self.spark, user_input)
subcommand = args.command[0].lower()
-
+
if args.auth is None:
args.auth = conf.get_auth_value(args.user, args.password)
else:
@@ -121,7 +216,9 @@ def spark(self, line, cell="", local_ns=None):
if subcommand == "info":
if args.url is not None and args.id is not None:
endpoint = Endpoint(args.url, initialize_auth(args))
- info_sessions = self.spark_controller.get_all_sessions_endpoint_info(endpoint)
+ info_sessions = self.spark_controller.get_all_sessions_endpoint_info(
+ endpoint
+ )
self._print_endpoint_info(info_sessions, args.id)
else:
self._print_local_info()
@@ -131,12 +228,14 @@ def spark(self, line, cell="", local_ns=None):
# add
elif subcommand == "add":
if args.url is None:
- self.ipython_display.send_error("Need to supply URL argument (e.g. -u https://example.com/livyendpoint)")
+ self.ipython_display.send_error(
+ "Need to supply URL argument (e.g. -u https://example.com/livyendpoint)"
+ )
return
name = args.session
language = args.language
-
+
endpoint = Endpoint(args.url, initialize_auth(args))
skip = args.skip
@@ -149,13 +248,17 @@ def spark(self, line, cell="", local_ns=None):
self.spark_controller.delete_session_by_name(args.session)
elif args.url is not None:
if args.id is None:
- self.ipython_display.send_error("Must provide --id or -i option to delete session at endpoint from URL")
+ self.ipython_display.send_error(
+ "Must provide --id or -i option to delete session at endpoint from URL"
+ )
return
endpoint = Endpoint(args.url, initialize_auth(args))
session_id = args.id
self.spark_controller.delete_session_by_id(endpoint, session_id)
else:
- self.ipython_display.send_error("Subcommand 'delete' requires a session name or a URL and session ID")
+ self.ipython_display.send_error(
+ "Subcommand 'delete' requires a session name or a URL and session ID"
+ )
# cleanup
elif subcommand == "cleanup":
if args.url is not None:
@@ -170,25 +273,52 @@ def spark(self, line, cell="", local_ns=None):
elif len(subcommand) == 0:
coerce = get_coerce_value(args.coerce)
if args.context == CONTEXT_NAME_SPARK:
- return self.execute_spark(cell, args.output, args.samplemethod,
- args.maxrows, args.samplefraction, args.session, coerce)
+ return self.execute_spark(
+ cell,
+ args.output,
+ args.samplemethod,
+ args.maxrows,
+ args.samplefraction,
+ args.session,
+ coerce,
+ )
elif args.context == CONTEXT_NAME_SQL:
- return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, args.samplefraction,
- args.session, args.output, args.quiet, coerce)
+ return self.execute_sqlquery(
+ cell,
+ args.samplemethod,
+ args.maxrows,
+ args.samplefraction,
+ args.session,
+ args.output,
+ args.quiet,
+ coerce,
+ )
else:
- self.ipython_display.send_error("Context '{}' not found".format(args.context))
+ self.ipython_display.send_error(
+ "Context '{}' not found".format(args.context)
+ )
# error
else:
- self.ipython_display.send_error("Subcommand '{}' not found. {}".format(subcommand, usage))
+ self.ipython_display.send_error(
+ "Subcommand '{}' not found. {}".format(subcommand, usage)
+ )
def _print_local_info(self):
- sessions_info = [" {}".format(i) for i in self.spark_controller.get_manager_sessions_str()]
- print("""Info for running Spark:
+ sessions_info = [
+ " {}".format(i)
+ for i in self.spark_controller.get_manager_sessions_str()
+ ]
+ print(
+ """Info for running Spark:
Sessions:
{}
Session configs:
{}
-""".format("\n".join(sessions_info), conf.session_configs()))
+""".format(
+ "\n".join(sessions_info), conf.session_configs()
+ )
+ )
+
def load_ipython_extension(ip):
ip.register_magics(RemoteSparkMagics)
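
Taken together with the docstring above, a typical notebook flow looks like the following (endpoint, credentials, and table name are illustrative):

    # cell 1: register a session (values here are hypothetical)
    %spark add -s testsession -l python -u https://sparkcluster.net/livy -t Kerberos -a my_user -p secret

    # cell 2: run SQL against it and capture a sampled pandas dataframe into my_var
    %%spark -s testsession -c sql -o my_var
    SELECT * FROM my_table LIMIT 10
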
diff --git a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py
index 5c5f62045..40c6b1802 100644
--- a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py
+++ b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py
@@ -16,48 +16,62 @@
from sparkmagic.utils.sparklogger import SparkLog
from sparkmagic.utils.sparkevents import SparkEvents
from sparkmagic.utils.utils import get_sessions_info_html
-from sparkmagic.utils.constants import MIMETYPE_TEXT_HTML
+from sparkmagic.utils.constants import MIMETYPE_TEXT_HTML
from sparkmagic.livyclientlib.sparkcontroller import SparkController
from sparkmagic.livyclientlib.sqlquery import SQLQuery
from sparkmagic.livyclientlib.command import Command
from sparkmagic.livyclientlib.sparkstorecommand import SparkStoreCommand
from sparkmagic.livyclientlib.exceptions import SparkStatementException
-from sparkmagic.livyclientlib.sendpandasdftosparkcommand import SendPandasDfToSparkCommand
+from sparkmagic.livyclientlib.sendpandasdftosparkcommand import (
+ SendPandasDfToSparkCommand,
+)
from sparkmagic.livyclientlib.sendstringtosparkcommand import SendStringToSparkCommand
from sparkmagic.livyclientlib.exceptions import BadUserDataException
-# How to display different cell content types in IPython
-SparkOutputHandler = namedtuple('SparkOutputHandler', ['html', 'text', 'default'])
+# How to display different cell content types in IPython
+SparkOutputHandler = namedtuple("SparkOutputHandler", ["html", "text", "default"])
@magics_class
class SparkMagicBase(Magics):
- _STRING_VAR_TYPE = 'str'
- _PANDAS_DATAFRAME_VAR_TYPE = 'df'
+ _STRING_VAR_TYPE = "str"
+ _PANDAS_DATAFRAME_VAR_TYPE = "df"
_ALLOWED_LOCAL_TO_SPARK_TYPES = [_STRING_VAR_TYPE, _PANDAS_DATAFRAME_VAR_TYPE]
def __init__(self, shell, data=None, spark_events=None):
# You must call the parent constructor
super(SparkMagicBase, self).__init__(shell)
- self.logger = SparkLog(u"SparkMagics")
+ self.logger = SparkLog("SparkMagics")
self.ipython_display = IpythonDisplay()
self.spark_controller = SparkController(self.ipython_display)
- self.logger.debug(u'Initialized spark magics.')
+ self.logger.debug("Initialized spark magics.")
if spark_events is None:
spark_events = SparkEvents()
spark_events.emit_library_loaded_event()
- def do_send_to_spark(self, cell, input_variable_name, var_type, output_variable_name, max_rows, session_name):
+ def do_send_to_spark(
+ self,
+ cell,
+ input_variable_name,
+ var_type,
+ output_variable_name,
+ max_rows,
+ session_name,
+ ):
try:
input_variable_value = self.shell.user_ns[input_variable_name]
except KeyError:
- raise BadUserDataException(u'Variable named {} not found.'.format(input_variable_name))
+ raise BadUserDataException(
+ "Variable named {} not found.".format(input_variable_name)
+ )
if input_variable_value is None:
- raise BadUserDataException(u'Value of {} is None!'.format(input_variable_name))
+ raise BadUserDataException(
+ "Value of {} is None!".format(input_variable_name)
+ )
if not output_variable_name:
output_variable_name = input_variable_name
@@ -67,26 +81,52 @@ def do_send_to_spark(self, cell, input_variable_name, var_type, output_variable_
input_variable_type = var_type.lower()
if input_variable_type == self._STRING_VAR_TYPE:
- command = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
+ command = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
elif input_variable_type == self._PANDAS_DATAFRAME_VAR_TYPE:
- command = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, max_rows)
+ command = SendPandasDfToSparkCommand(
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ max_rows,
+ )
else:
- raise BadUserDataException(u'Invalid or incorrect -t type. Available are: [{}]'.format(u','.join(self._ALLOWED_LOCAL_TO_SPARK_TYPES)))
+ raise BadUserDataException(
+ "Invalid or incorrect -t type. Available are: [{}]".format(
+ ",".join(self._ALLOWED_LOCAL_TO_SPARK_TYPES)
+ )
+ )
(success, result, mime_type) = self.spark_controller.run_command(command, None)
if not success:
self.ipython_display.send_error(result)
else:
- self.ipython_display.write(u'Successfully passed \'{}\' as \'{}\' to Spark'
- u' kernel'.format(input_variable_name, output_variable_name))
-
- def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction, session_name, coerce, output_handler=None):
+ self.ipython_display.write(
+ "Successfully passed '{}' as '{}' to Spark"
+ " kernel".format(input_variable_name, output_variable_name)
+ )
+
+ def execute_spark(
+ self,
+ cell,
+ output_var,
+ samplemethod,
+ maxrows,
+ samplefraction,
+ session_name,
+ coerce,
+ output_handler=None,
+ ):
output_handler = output_handler or SparkOutputHandler(
- html=self.ipython_display.html,
- text=self.ipython_display.write,
- default=self.ipython_display.display)
-
- (success, out, mimetype) = self.spark_controller.run_command(Command(cell), session_name)
+ html=self.ipython_display.html,
+ text=self.ipython_display.write,
+ default=self.ipython_display.display,
+ )
+
+ (success, out, mimetype) = self.spark_controller.run_command(
+ Command(cell), session_name
+ )
if not success:
if conf.shutdown_session_on_spark_statement_errors():
self.spark_controller.cleanup()
@@ -101,16 +141,31 @@ def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction,
else:
output_handler.default(out)
if output_var is not None:
- spark_store_command = self._spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce)
- df = self.spark_controller.run_command(spark_store_command, session_name)
+ spark_store_command = self._spark_store_command(
+ output_var, samplemethod, maxrows, samplefraction, coerce
+ )
+ df = self.spark_controller.run_command(
+ spark_store_command, session_name
+ )
self.shell.user_ns[output_var] = df
@staticmethod
def _spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce):
- return SparkStoreCommand(output_var, samplemethod, maxrows, samplefraction, coerce=coerce)
-
- def execute_sqlquery(self, cell, samplemethod, maxrows, samplefraction,
- session, output_var, quiet, coerce):
+ return SparkStoreCommand(
+ output_var, samplemethod, maxrows, samplefraction, coerce=coerce
+ )
+
+ def execute_sqlquery(
+ self,
+ cell,
+ samplemethod,
+ maxrows,
+ samplefraction,
+ session,
+ output_var,
+ quiet,
+ coerce,
+ ):
sqlquery = self._sqlquery(cell, samplemethod, maxrows, samplefraction, coerce)
df = self.spark_controller.run_sqlquery(sqlquery, session)
if output_var is not None:
@@ -130,4 +185,4 @@ def _print_endpoint_info(self, info_sessions, current_session_id):
html = get_sessions_info_html(info_sessions, current_session_id)
self.ipython_display.html(html)
else:
- self.ipython_display.html(u'No active sessions.')
+ self.ipython_display.html("No active sessions.")
diff --git a/sparkmagic/sparkmagic/serverextension/handlers.py b/sparkmagic/sparkmagic/serverextension/handlers.py
index c94a258ca..8be29de15 100644
--- a/sparkmagic/sparkmagic/serverextension/handlers.py
+++ b/sparkmagic/sparkmagic/serverextension/handlers.py
@@ -19,7 +19,7 @@ class ReconnectHandler(IPythonHandler):
@web.authenticated
@gen.coroutine
def post(self):
- self.logger = SparkLog(u"ReconnectHandler")
+ self.logger = SparkLog("ReconnectHandler")
spark_events = self._get_spark_events()
@@ -35,13 +35,13 @@ def post(self):
endpoint = None
try:
- path = self._get_argument_or_raise(data, 'path')
- username = self._get_argument_or_raise(data, 'username')
- password = self._get_argument_or_raise(data, 'password')
- endpoint = self._get_argument_or_raise(data, 'endpoint')
- auth = self._get_argument_if_exists(data, 'auth')
+ path = self._get_argument_or_raise(data, "path")
+ username = self._get_argument_or_raise(data, "username")
+ password = self._get_argument_or_raise(data, "password")
+ endpoint = self._get_argument_or_raise(data, "endpoint")
+ auth = self._get_argument_if_exists(data, "auth")
if auth is None:
- if username == '' and password == '':
+ if username == "" and password == "":
auth = constants.NO_AUTH
else:
auth = constants.AUTH_BASIC
@@ -59,7 +59,13 @@ def post(self):
# Execute code
client = kernel_manager.client()
- code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, endpoint, username, password, auth)
+ code = "%{} -s {} -u {} -p {} -t {}".format(
+ KernelMagics._do_not_call_change_endpoint.__name__,
+ endpoint,
+ username,
+ password,
+ auth,
+ )
response_id = client.execute(code, silent=False, store_history=False)
msg = client.get_shell_msg(response_id)
@@ -69,16 +75,20 @@ def post(self):
if successful_message:
status_code = 200
else:
- self.logger.error(u"Code to reconnect errored out: {}".format(error))
+ self.logger.error("Code to reconnect errored out: {}".format(error))
status_code = 500
# Post execution info
self.set_status(status_code)
- self.finish(json.dumps(dict(success=successful_message, error=error), sort_keys=True))
- spark_events.emit_cluster_change_event(endpoint, status_code, successful_message, error)
+ self.finish(
+ json.dumps(dict(success=successful_message, error=error), sort_keys=True)
+ )
+ spark_events.emit_cluster_change_event(
+ endpoint, status_code, successful_message, error
+ )
def _get_kernel_name(self, data):
- kernel_name = self._get_argument_if_exists(data, 'kernelname')
+ kernel_name = self._get_argument_if_exists(data, "kernelname")
self.logger.debug("Kernel name is {}".format(kernel_name))
if kernel_name is None:
kernel_name = conf.server_extension_default_kernel_name()
@@ -100,21 +110,25 @@ def _get_kernel_manager(self, path, kernel_name):
kernel_id = None
for session in sessions:
- if session['notebook']['path'] == path:
- session_id = session['id']
- kernel_id = session['kernel']['id']
- existing_kernel_name = session['kernel']['name']
+ if session["notebook"]["path"] == path:
+ session_id = session["id"]
+ kernel_id = session["kernel"]["id"]
+ existing_kernel_name = session["kernel"]["name"]
break
if kernel_id is None:
- self.logger.debug(u"Kernel not found. Starting a new kernel.")
+ self.logger.debug("Kernel not found. Starting a new kernel.")
k_m = yield self._get_kernel_manager_new_session(path, kernel_name)
elif existing_kernel_name != kernel_name:
- self.logger.debug(u"Existing kernel name '{}' does not match requested '{}'. Starting a new kernel.".format(existing_kernel_name, kernel_name))
+ self.logger.debug(
+ "Existing kernel name '{}' does not match requested '{}'. Starting a new kernel.".format(
+ existing_kernel_name, kernel_name
+ )
+ )
self._delete_session(session_id)
k_m = yield self._get_kernel_manager_new_session(path, kernel_name)
else:
- self.logger.debug(u"Kernel found. Restarting kernel.")
+ self.logger.debug("Kernel found. Restarting kernel.")
k_m = self.kernel_manager.get_kernel(kernel_id)
k_m.restart_kernel()
@@ -122,7 +136,9 @@ def _get_kernel_manager(self, path, kernel_name):
@gen.coroutine
def _get_kernel_manager_new_session(self, path, kernel_name):
- model_future = self.session_manager.create_session(kernel_name=kernel_name, path=path, type="notebook")
+ model_future = self.session_manager.create_session(
+ kernel_name=kernel_name, path=path, type="notebook"
+ )
model = yield model_future
kernel_id = model["kernel"]["id"]
self.logger.debug("Kernel created with id {}".format(str(kernel_id)))
@@ -133,18 +149,18 @@ def _delete_session(self, session_id):
self.session_manager.delete_session(session_id)
def _msg_status(self, msg):
- return msg['content']['status']
+ return msg["content"]["status"]
def _msg_successful(self, msg):
- return self._msg_status(msg) == 'ok'
+ return self._msg_status(msg) == "ok"
def _msg_error(self, msg):
- if self._msg_status(msg) != 'error':
+ if self._msg_status(msg) != "error":
return None
- return u'{}:\n{}'.format(msg['content']['ename'], msg['content']['evalue'])
+ return "{}:\n{}".format(msg["content"]["ename"], msg["content"]["evalue"])
def _get_spark_events(self):
- spark_events = getattr(self, 'spark_events', None)
+ spark_events = getattr(self, "spark_events", None)
if spark_events is None:
return SparkEvents()
return spark_events
@@ -154,10 +170,10 @@ def load_jupyter_server_extension(nb_app):
nb_app.log.info("sparkmagic extension enabled!")
web_app = nb_app.web_app
- base_url = web_app.settings['base_url']
- host_pattern = '.*$'
+ base_url = web_app.settings["base_url"]
+ host_pattern = ".*$"
- route_pattern_reconnect = url_path_join(base_url, '/reconnectsparkmagic')
+ route_pattern_reconnect = url_path_join(base_url, "/reconnectsparkmagic")
handlers = [(route_pattern_reconnect, ReconnectHandler)]
web_app.add_handlers(host_pattern, handlers)
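
For reference, a request body the `ReconnectHandler` above would accept might look like this sketch (field values are illustrative; `auth` and `kernelname` are optional, and empty credentials fall back to no-auth):

    import json

    payload = {
        "path": "notebooks/analysis.ipynb",      # notebook path used to find or start a kernel
        "username": "user",
        "password": "secret",
        "endpoint": "https://livy.example.com:8998",
        "auth": "Basic_Access",                  # assumption: matches constants.AUTH_BASIC
        "kernelname": "pysparkkernel",           # optional; defaults via configuration
    }
    print(json.dumps(payload, sort_keys=True))
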
diff --git a/sparkmagic/sparkmagic/tests/test_command.py b/sparkmagic/sparkmagic/tests/test_command.py
index 8e51c535d..5aae1a9b6 100644
--- a/sparkmagic/sparkmagic/tests/test_command.py
+++ b/sparkmagic/sparkmagic/tests/test_command.py
@@ -8,8 +8,13 @@
import sparkmagic.livyclientlib.exceptions
import sparkmagic.utils.configuration as conf
-from sparkmagic.utils.constants import SESSION_KIND_SPARK, MIMETYPE_IMAGE_PNG, MIMETYPE_TEXT_HTML, \
- MIMETYPE_TEXT_PLAIN, COMMAND_INTERRUPTED_MSG
+from sparkmagic.utils.constants import (
+ SESSION_KIND_SPARK,
+ MIMETYPE_IMAGE_PNG,
+ MIMETYPE_TEXT_HTML,
+ MIMETYPE_TEXT_PLAIN,
+ COMMAND_INTERRUPTED_MSG,
+)
from sparkmagic.livyclientlib.command import Command
from sparkmagic.livyclientlib.livysession import LivySession
from sparkmagic.livyclientlib.exceptions import SparkStatementCancelledException
@@ -28,15 +33,21 @@ def _setup():
conf.override_all({})
-def _create_session(kind=SESSION_KIND_SPARK, session_id=-1,
- http_client=None, spark_events=None):
+def _create_session(
+ kind=SESSION_KIND_SPARK, session_id=-1, http_client=None, spark_events=None
+):
if http_client is None:
http_client = MagicMock()
if spark_events is None:
spark_events = MagicMock()
ipython_display = MagicMock()
- session = LivySession(http_client, {"kind": kind, "heartbeatTimeoutInSecond": 60},
- ipython_display, session_id, spark_events)
+ session = LivySession(
+ http_client,
+ {"kind": kind, "heartbeatTimeoutInSecond": 60},
+ ipython_display,
+ session_id,
+ spark_events,
+ )
return session
@@ -69,14 +80,23 @@ def test_execute():
assert result[0]
assert_equals(tls.TestLivySession.pi_result, result[1])
assert_equals(MIMETYPE_TEXT_PLAIN, result[2])
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- 0, True, "", "")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid, 0, True, "", ""
+ )
# Now try with PNG result:
- http_client.get_statement.return_value = {"id":0,"state":"available","output":{"status":"ok", "execution_count":0,"data":{"text/plain":"", "image/png": b64encode(b"hello")}}}
+ http_client.get_statement.return_value = {
+ "id": 0,
+ "state": "available",
+ "output": {
+ "status": "ok",
+ "execution_count": 0,
+ "data": {"text/plain": "", "image/png": b64encode(b"hello")},
+ },
+ }
result = command.execute(session)
assert result[0]
assert isinstance(result[1], Image)
@@ -84,10 +104,18 @@ def test_execute():
assert_equals(MIMETYPE_IMAGE_PNG, result[2])
# Now try with HTML result:
- http_client.get_statement.return_value = {"id":0,"state":"available","output":{"status":"ok", "execution_count":0,"data":{"text/html":"out
"}}}
+ http_client.get_statement.return_value = {
+ "id": 0,
+ "state": "available",
+ "output": {
+ "status": "ok",
+ "execution_count": 0,
+ "data": {"text/html": "out
"},
+ },
+ }
result = command.execute(session)
assert result[0]
- assert_equals(u"out
", result[1])
+ assert_equals("out
", result[1])
assert_equals(MIMETYPE_TEXT_HTML, result[2])
@@ -99,7 +127,12 @@ def test_execute_waiting():
http_client.post_session.return_value = tls.TestLivySession.session_create_json
http_client.post_statement.return_value = tls.TestLivySession.post_statement_json
http_client.get_session.return_value = tls.TestLivySession.ready_sessions_json
- http_client.get_statement.side_effect = [tls.TestLivySession.waiting_statement_json, tls.TestLivySession.waiting_statement_json, tls.TestLivySession.ready_statement_json, tls.TestLivySession.ready_statement_json]
+ http_client.get_statement.side_effect = [
+ tls.TestLivySession.waiting_statement_json,
+ tls.TestLivySession.waiting_statement_json,
+ tls.TestLivySession.ready_statement_json,
+ tls.TestLivySession.ready_statement_json,
+ ]
session = _create_session(kind=kind, http_client=http_client)
session.start()
command = Command("command", spark_events=spark_events)
@@ -111,11 +144,12 @@ def test_execute_waiting():
assert result[0]
assert_equals(tls.TestLivySession.pi_result, result[1])
assert_equals(MIMETYPE_TEXT_PLAIN, result[2])
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- 0, True, "", "")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid, 0, True, "", ""
+ )
@with_setup(_setup)
@@ -126,7 +160,9 @@ def test_execute_null_ouput():
http_client.post_session.return_value = tls.TestLivySession.session_create_json
http_client.post_statement.return_value = tls.TestLivySession.post_statement_json
http_client.get_session.return_value = tls.TestLivySession.ready_sessions_json
- http_client.get_statement.return_value = tls.TestLivySession.ready_statement_null_output_json
+ http_client.get_statement.return_value = (
+ tls.TestLivySession.ready_statement_null_output_json
+ )
session = _create_session(kind=kind, http_client=http_client)
session.start()
command = Command("command", spark_events=spark_events)
@@ -136,13 +172,14 @@ def test_execute_null_ouput():
http_client.post_statement.assert_called_with(0, {"code": command.code})
http_client.get_statement.assert_called_with(0, 0)
assert result[0]
- assert_equals(u"", result[1])
+ assert_equals("", result[1])
assert_equals(MIMETYPE_TEXT_PLAIN, result[2])
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- 0, True, "", "")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid, 0, True, "", ""
+ )
@with_setup(_setup)
@@ -163,11 +200,19 @@ def test_execute_failure_wait_for_session_emits_event():
result = command.execute(session)
assert False
except ValueError as e:
- spark_events.emit_statement_execution_start_event.assert_called_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- -1, False, "ValueError", "yo")
+ spark_events.emit_statement_execution_start_event.assert_called_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event.assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ command.guid,
+ -1,
+ False,
+ "ValueError",
+ "yo",
+ )
assert_equals(e, session.wait_for_idle.side_effect)
@@ -183,17 +228,24 @@ def test_execute_failure_post_statement_emits_event():
session.wait_for_idle = MagicMock()
command = Command("command", spark_events=spark_events)
- http_client.post_statement.side_effect = KeyError('Something bad happened here')
+ http_client.post_statement.side_effect = KeyError("Something bad happened here")
try:
result = command.execute(session)
assert False
except KeyError as e:
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- -1, False, "KeyError",
- "Something bad happened here")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event._assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ command.guid,
+ -1,
+ False,
+ "KeyError",
+ "Something bad happened here",
+ )
assert_equals(e, http_client.post_statement.side_effect)
@@ -209,18 +261,25 @@ def test_execute_failure_get_statement_output_emits_event():
session.start()
session.wait_for_idle = MagicMock()
command = Command("command", spark_events=spark_events)
- command._get_statement_output = MagicMock(side_effect=AttributeError('OHHHH'))
+ command._get_statement_output = MagicMock(side_effect=AttributeError("OHHHH"))
try:
result = command.execute(session)
assert False
except AttributeError as e:
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- -1, False, "AttributeError",
- "OHHHH")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event._assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ command.guid,
+ -1,
+ False,
+ "AttributeError",
+ "OHHHH",
+ )
assert_equals(e, command._get_statement_output.side_effect)
@@ -245,12 +304,19 @@ def test_execute_interrupted():
result = command.execute(session)
assert False
except KeyboardInterrupt as e:
- spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, command.guid)
- spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind,
- session.id, command.guid,
- -1, False, "KeyboardInterrupt",
- "")
+ spark_events.emit_statement_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, command.guid
+ )
+ spark_events.emit_statement_execution_end_event._assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ command.guid,
+ -1,
+ False,
+ "KeyboardInterrupt",
+ "",
+ )
assert isinstance(e, SparkStatementCancelledException)
assert_equals(str(e), COMMAND_INTERRUPTED_MSG)
@@ -263,9 +329,11 @@ def test_execute_interrupted():
assert not stderr.getvalue()
with _capture_stderr() as stderr:
- mock_ipython._showtraceback(SparkStatementCancelledException, COMMAND_INTERRUPTED_MSG, MagicMock())
- mock_show_tb.assert_called_once() # still once
- assert_equals(stderr.getvalue().strip(), COMMAND_INTERRUPTED_MSG)
+ mock_ipython._showtraceback(
+ SparkStatementCancelledException, COMMAND_INTERRUPTED_MSG, MagicMock()
+ )
+ mock_show_tb.assert_called_once() # still once
+ assert_equals(stderr.getvalue().strip(), COMMAND_INTERRUPTED_MSG)
except:
assert False
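
These tests all stub `http_client.get_statement` with the Livy statement shape below; a sketch (field values illustrative) of the object `Command.execute` consumes, dispatching on the mimetypes under `output.data`:

    statement = {
        "id": 0,
        "state": "available",                # polled until no longer waiting
        "output": {
            "status": "ok",
            "execution_count": 0,
            "data": {"text/plain": "3.14"},  # or "image/png" / "text/html"
        },
    }
    mimetype, value = next(iter(statement["output"]["data"].items()))
    print(mimetype, value)  # text/plain 3.14
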
diff --git a/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py b/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py
index 680b3303c..383a81318 100644
--- a/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py
+++ b/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py
@@ -14,17 +14,17 @@ def test_with_empty_list():
assert_equals(5, policy.seconds_to_sleep(4))
assert_equals(5, policy.seconds_to_sleep(5))
assert_equals(5, policy.seconds_to_sleep(6))
-
+
# Check based on retry count
assert_equals(True, policy.should_retry(500, False, 0))
assert_equals(True, policy.should_retry(500, False, 4))
assert_equals(True, policy.should_retry(500, False, 5))
assert_equals(False, policy.should_retry(500, False, 6))
-
+
# Check based on status code
assert_equals(False, policy.should_retry(201, False, 0))
assert_equals(False, policy.should_retry(201, False, 6))
-
+
# Check based on error
assert_equals(True, policy.should_retry(201, True, 0))
assert_equals(True, policy.should_retry(201, True, 6))
@@ -45,11 +45,11 @@ def test_with_one_element_list():
assert_equals(True, policy.should_retry(500, False, 4))
assert_equals(True, policy.should_retry(500, False, 5))
assert_equals(False, policy.should_retry(500, False, 6))
-
+
# Check based on status code
assert_equals(False, policy.should_retry(201, False, 0))
assert_equals(False, policy.should_retry(201, False, 6))
-
+
# Check based on error
assert_equals(True, policy.should_retry(201, True, 0))
assert_equals(True, policy.should_retry(201, True, 6))
@@ -76,11 +76,11 @@ def test_with_default_values():
assert_equals(True, policy.should_retry(500, False, 7))
assert_equals(True, policy.should_retry(500, False, 8))
assert_equals(False, policy.should_retry(500, False, 9))
-
+
# Check based on status code
assert_equals(False, policy.should_retry(201, False, 0))
assert_equals(False, policy.should_retry(201, False, 9))
-
+
# Check based on error
assert_equals(True, policy.should_retry(201, True, 0))
assert_equals(True, policy.should_retry(201, True, 9))
@@ -89,7 +89,7 @@ def test_with_default_values():
def test_with_negative_values():
times = [0.1, -1]
max_retries = 5
-
+
try:
policy = ConfigurableRetryPolicy(times, max_retries)
assert False
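
The behavior these tests pin down can be summarized in a standalone sketch (assumed semantics, not the library implementation): sleep times clamp to the last configured value, an empty list defaults to 5 seconds, HTTP 5xx retries are capped by `max_retries`, transport errors always retry, and negative sleep values are rejected at construction:

    class SketchRetryPolicy:
        def __init__(self, times, max_retries):
            if any(t < 0 for t in times):
                raise ValueError("sleep times must be non-negative")
            self._times = times or [5]       # empty list falls back to 5 seconds
            self._max_retries = max_retries

        def seconds_to_sleep(self, retry_count):
            return self._times[min(retry_count, len(self._times) - 1)]

        def should_retry(self, status_code, error, retry_count):
            if error:
                return True                  # errors retry regardless of count
            return status_code >= 500 and retry_count <= self._max_retries

    policy = SketchRetryPolicy([0.1, 1], max_retries=5)
    print(policy.seconds_to_sleep(6), policy.should_retry(500, False, 6))  # 1 False
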
diff --git a/sparkmagic/sparkmagic/tests/test_configuration.py b/sparkmagic/sparkmagic/tests/test_configuration.py
index 52ca0e2f5..aff41a3da 100644
--- a/sparkmagic/sparkmagic/tests/test_configuration.py
+++ b/sparkmagic/sparkmagic/tests/test_configuration.py
@@ -9,65 +9,104 @@
def _setup():
conf.override_all({})
-
+
@with_setup(_setup)
def test_configuration_override_base64_password():
- kpc = { 'username': 'U', 'password': 'P', 'base64_password': 'cGFzc3dvcmQ=', 'url': 'L', "auth": AUTH_BASIC }
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {
+ "username": "U",
+ "password": "P",
+ "base64_password": "cGFzc3dvcmQ=",
+ "url": "L",
+ "auth": AUTH_BASIC,
+ }
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1)
- assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc,
- conf.livy_session_startup_timeout_seconds.__name__: 1 })
+ assert_equals(
+ conf.d,
+ {
+ conf.kernel_python_credentials.__name__: kpc,
+ conf.livy_session_startup_timeout_seconds.__name__: 1,
+ },
+ )
assert_equals(conf.livy_session_startup_timeout_seconds(), 1)
- assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'password', 'url': 'L', 'auth': AUTH_BASIC })
+ assert_equals(
+ conf.base64_kernel_python_credentials(),
+ {"username": "U", "password": "password", "url": "L", "auth": AUTH_BASIC},
+ )
@with_setup(_setup)
def test_configuration_auth_missing_basic_auth():
- kpc = { 'username': 'U', 'password': 'P', 'url': 'L'}
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {"username": "U", "password": "P", "url": "L"}
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
- assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': AUTH_BASIC })
+ assert_equals(
+ conf.base64_kernel_python_credentials(),
+ {"username": "U", "password": "P", "url": "L", "auth": AUTH_BASIC},
+ )
@with_setup(_setup)
def test_configuration_auth_missing_no_auth():
- kpc = { 'username': '', 'password': '', 'url': 'L'}
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {"username": "", "password": "", "url": "L"}
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
- assert_equals(conf.base64_kernel_python_credentials(), { 'username': '', 'password': '', 'url': 'L', 'auth': NO_AUTH })
+ assert_equals(
+ conf.base64_kernel_python_credentials(),
+ {"username": "", "password": "", "url": "L", "auth": NO_AUTH},
+ )
@with_setup(_setup)
def test_configuration_override_fallback_to_password():
- kpc = { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': NO_AUTH }
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {"username": "U", "password": "P", "url": "L", "auth": NO_AUTH}
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1)
- assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc,
- conf.livy_session_startup_timeout_seconds.__name__: 1 })
+ assert_equals(
+ conf.d,
+ {
+ conf.kernel_python_credentials.__name__: kpc,
+ conf.livy_session_startup_timeout_seconds.__name__: 1,
+ },
+ )
assert_equals(conf.livy_session_startup_timeout_seconds(), 1)
assert_equals(conf.base64_kernel_python_credentials(), kpc)
@with_setup(_setup)
def test_configuration_override_work_with_empty_password():
- kpc = { 'username': 'U', 'base64_password': '', 'password': '', 'url': '', 'auth': AUTH_BASIC }
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {
+ "username": "U",
+ "base64_password": "",
+ "password": "",
+ "url": "",
+ "auth": AUTH_BASIC,
+ }
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1)
- assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc,
- conf.livy_session_startup_timeout_seconds.__name__: 1 })
+ assert_equals(
+ conf.d,
+ {
+ conf.kernel_python_credentials.__name__: kpc,
+ conf.livy_session_startup_timeout_seconds.__name__: 1,
+ },
+ )
assert_equals(conf.livy_session_startup_timeout_seconds(), 1)
- assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': '', 'url': '', 'auth': AUTH_BASIC })
+ assert_equals(
+ conf.base64_kernel_python_credentials(),
+ {"username": "U", "password": "", "url": "", "auth": AUTH_BASIC},
+ )
@raises(BadUserConfigurationException)
@with_setup(_setup)
def test_configuration_raise_error_for_bad_base64_password():
- kpc = { 'username': 'U', 'base64_password': 'P', 'url': 'L' }
- overrides = { conf.kernel_python_credentials.__name__: kpc }
+ kpc = {"username": "U", "base64_password": "P", "url": "L"}
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
conf.override_all(overrides)
conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1)
conf.base64_kernel_python_credentials()
@@ -75,6 +114,15 @@ def test_configuration_raise_error_for_bad_base64_password():
@with_setup(_setup)
def test_share_config_between_pyspark_and_pyspark3():
- kpc = { 'username': 'U', 'password': 'P', 'base64_password': 'cGFzc3dvcmQ=', 'url': 'L', 'auth': AUTH_BASIC }
- overrides = { conf.kernel_python_credentials.__name__: kpc }
- assert_equals(conf.base64_kernel_python3_credentials(), conf.base64_kernel_python_credentials())
+ kpc = {
+ "username": "U",
+ "password": "P",
+ "base64_password": "cGFzc3dvcmQ=",
+ "url": "L",
+ "auth": AUTH_BASIC,
+ }
+ overrides = {conf.kernel_python_credentials.__name__: kpc}
+ assert_equals(
+ conf.base64_kernel_python3_credentials(),
+ conf.base64_kernel_python_credentials(),
+ )
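
These credential dictionaries correspond to entries in the sparkmagic config file (conventionally `~/.sparkmagic/config.json`, an assumption here); a sketch of the decode that the `base64_password` path performs, with illustrative values:

    import base64

    kernel_python_credentials = {
        "username": "U",
        "base64_password": "cGFzc3dvcmQ=",   # takes precedence over "password"
        "url": "https://livy.example.com:8998",
        "auth": "Basic_Access",              # if omitted, inferred from the credentials
    }
    print(base64.b64decode(kernel_python_credentials["base64_password"]).decode())  # password
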
diff --git a/sparkmagic/sparkmagic/tests/test_endpoint.py b/sparkmagic/sparkmagic/tests/test_endpoint.py
index c3491cf84..1a2259d64 100644
--- a/sparkmagic/sparkmagic/tests/test_endpoint.py
+++ b/sparkmagic/sparkmagic/tests/test_endpoint.py
@@ -5,20 +5,30 @@
from sparkmagic.auth.basic import Basic
from sparkmagic.auth.kerberos import Kerberos
+
def test_equality():
basic_auth1 = Basic()
basic_auth2 = Basic()
kerberos_auth1 = Kerberos()
kerberos_auth2 = Kerberos()
- assert_equals(Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2))
- assert_equals(Endpoint("http://url.com", kerberos_auth1), Endpoint("http://url.com", kerberos_auth2))
+ assert_equals(
+ Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2)
+ )
+ assert_equals(
+ Endpoint("http://url.com", kerberos_auth1),
+ Endpoint("http://url.com", kerberos_auth2),
+ )
+
def test_inequality():
basic_auth1 = Basic()
basic_auth2 = Basic()
- basic_auth1.username = 'user'
- basic_auth2.username = 'different_user'
- assert_not_equal(Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2))
+ basic_auth1.username = "user"
+ basic_auth2.username = "different_user"
+ assert_not_equal(
+ Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2)
+ )
+
def test_invalid_url():
basic_auth = Basic()
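
A one-look sketch of the equality these tests exercise (assuming `Endpoint` compares by URL plus auth object, and the auth classes compare by their fields):

    class SketchEndpoint:
        def __init__(self, url, auth):
            self.url, self.auth = url, auth

        def __eq__(self, other):
            return self.url == other.url and self.auth == other.auth

    print(SketchEndpoint("http://url.com", "basic") == SketchEndpoint("http://url.com", "basic"))  # True
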
diff --git a/sparkmagic/sparkmagic/tests/test_exceptions.py b/sparkmagic/sparkmagic/tests/test_exceptions.py
index a0ce92300..db2e91334 100644
--- a/sparkmagic/sparkmagic/tests/test_exceptions.py
+++ b/sparkmagic/sparkmagic/tests/test_exceptions.py
@@ -21,7 +21,7 @@ def _setup():
@with_setup(_setup)
def test_handle_expected_exceptions():
mock_method = MagicMock()
- mock_method.__name__ = 'MockMethod'
+ mock_method.__name__ = "MockMethod"
decorated = handle_expected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
@@ -33,51 +33,48 @@ def test_handle_expected_exceptions():
@with_setup(_setup)
def test_handle_expected_exceptions_handle():
- conf.override_all({
- 'all_errors_are_fatal': False
- })
- mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('ridiculous'))
- mock_method.__name__ = 'MockMethod2'
+ conf.override_all({"all_errors_are_fatal": False})
+ mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("ridiculous"))
+ mock_method.__name__ = "MockMethod2"
decorated = handle_expected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
- result = decorated(self, 1, kwarg='foo')
+ result = decorated(self, 1, kwarg="foo")
assert_is(result, None)
assert_equals(ipython_display.send_error.call_count, 1)
- mock_method.assert_called_once_with(self, 1, kwarg='foo')
+ mock_method.assert_called_once_with(self, 1, kwarg="foo")
@raises(ValueError)
@with_setup(_setup)
def test_handle_expected_exceptions_throw():
- mock_method = MagicMock(side_effect=ValueError('HALP'))
- mock_method.__name__ = 'mock_meth'
+ mock_method = MagicMock(side_effect=ValueError("HALP"))
+ mock_method.__name__ = "mock_meth"
decorated = handle_expected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
- result = decorated(self, 1, kwarg='foo')
+ result = decorated(self, 1, kwarg="foo")
@raises(LivyUnexpectedStatusException)
@with_setup(_setup)
def test_handle_expected_exceptions_throws_if_all_errors_fatal():
- conf.override_all({
- 'all_errors_are_fatal': True
- })
- mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('Oh no!'))
- mock_method.__name__ = 'mock_meth'
+ conf.override_all({"all_errors_are_fatal": True})
+ mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("Oh no!"))
+ mock_method.__name__ = "mock_meth"
decorated = handle_expected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
- result = decorated(self, 1, kwarg='foo')
+ result = decorated(self, 1, kwarg="foo")
# test wrap with unexpected to true
+
@with_setup(_setup)
def test_wrap_unexpected_exceptions():
mock_method = MagicMock()
- mock_method.__name__ = 'tos'
+ mock_method.__name__ = "tos"
decorated = wrap_unexpected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
@@ -89,27 +86,26 @@ def test_wrap_unexpected_exceptions():
@with_setup(_setup)
def test_wrap_unexpected_exceptions_handle():
- mock_method = MagicMock(side_effect=ValueError('~~~~~~'))
- mock_method.__name__ = 'tos'
+ mock_method = MagicMock(side_effect=ValueError("~~~~~~"))
+ mock_method.__name__ = "tos"
decorated = wrap_unexpected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
- result = decorated(self, 'FOOBAR', FOOBAR='FOOBAR')
+ result = decorated(self, "FOOBAR", FOOBAR="FOOBAR")
assert_is(result, None)
assert_equals(ipython_display.send_error.call_count, 1)
- mock_method.assert_called_once_with(self, 'FOOBAR', FOOBAR='FOOBAR')
+ mock_method.assert_called_once_with(self, "FOOBAR", FOOBAR="FOOBAR")
+
# test wrap with unexpected to true
# test wrap with all to true
@raises(ValueError)
@with_setup(_setup)
def test_wrap_unexpected_exceptions_throws_if_all_errors_fatal():
- conf.override_all({
- 'all_errors_are_fatal': True
- })
- mock_method = MagicMock(side_effect=ValueError('~~~~~~'))
- mock_method.__name__ = 'tos'
+ conf.override_all({"all_errors_are_fatal": True})
+ mock_method = MagicMock(side_effect=ValueError("~~~~~~"))
+ mock_method.__name__ = "tos"
decorated = wrap_unexpected_exceptions(mock_method)
assert_equals(decorated.__name__, mock_method.__name__)
- result = decorated(self, 'FOOBAR', FOOBAR='FOOBAR')
+ result = decorated(self, "FOOBAR", FOOBAR="FOOBAR")
diff --git a/sparkmagic/sparkmagic/tests/test_handlers.py b/sparkmagic/sparkmagic/tests/test_handlers.py
index 4d01734fd..fb441d88c 100644
--- a/sparkmagic/sparkmagic/tests/test_handlers.py
+++ b/sparkmagic/sparkmagic/tests/test_handlers.py
@@ -23,24 +23,33 @@ class TestSparkMagicHandler(AsyncTestCase):
client = None
session_list = None
spark_events = None
- path = 'some_path.ipynb'
- kernel_id = '1'
- kernel_name = 'pysparkkernel'
- session_id = '1'
- username = 'username'
- password = 'password'
- endpoint = 'http://endpoint.com'
+ path = "some_path.ipynb"
+ kernel_id = "1"
+ kernel_name = "pysparkkernel"
+ session_id = "1"
+ username = "username"
+ password = "password"
+ endpoint = "http://endpoint.com"
auth = constants.AUTH_BASIC
- response_id = '0'
- good_msg = dict(content=dict(status='ok'))
- bad_msg = dict(content=dict(status='error', ename='SyntaxError', evalue='oh no!'))
+ response_id = "0"
+ good_msg = dict(content=dict(status="ok"))
+ bad_msg = dict(content=dict(status="error", ename="SyntaxError", evalue="oh no!"))
request = None
def create_session_dict(self, path, kernel_id):
- return dict(notebook=dict(path=path), kernel=dict(id=kernel_id, name=self.kernel_name), id=self.session_id)
+ return dict(
+ notebook=dict(path=path),
+ kernel=dict(id=kernel_id, name=self.kernel_name),
+ id=self.session_id,
+ )
def get_argument(self, key):
- return dict(username=self.username, password=self.password, endpoint=self.endpoint, path=self.path)[key]
+ return dict(
+ username=self.username,
+ password=self.password,
+ endpoint=self.endpoint,
+ path=self.path,
+ )[key]
def setUp(self):
# Mock kernel manager
@@ -50,20 +59,32 @@ def setUp(self):
self.individual_kernel_manager = MagicMock()
self.individual_kernel_manager.client = MagicMock(return_value=self.client)
self.kernel_manager = MagicMock()
- self.kernel_manager.get_kernel = MagicMock(return_value=self.individual_kernel_manager)
+ self.kernel_manager.get_kernel = MagicMock(
+ return_value=self.individual_kernel_manager
+ )
# Mock session manager
self.session_list = [self.create_session_dict(self.path, self.kernel_id)]
self.session_manager = MagicMock()
self.session_manager.list_sessions = MagicMock(return_value=self.session_list)
- self.session_manager.create_session = MagicMock(return_value=self.create_session_dict(self.path, self.kernel_id))
+ self.session_manager.create_session = MagicMock(
+ return_value=self.create_session_dict(self.path, self.kernel_id)
+ )
# Mock spark events
self.spark_events = MagicMock()
# Mock request
self.request = MagicMock()
- self.request.body = json.dumps({"path": self.path, "username": self.username, "password": self.password, "endpoint": self.endpoint, "auth": self.auth})
+ self.request.body = json.dumps(
+ {
+ "path": self.path,
+ "username": self.username,
+ "password": self.password,
+ "endpoint": self.endpoint,
+ "auth": self.auth,
+ }
+ )
# Create mocked reconnect_handler
ReconnectHandler.__bases__ = (SimpleObject,)
@@ -73,15 +94,15 @@ def setUp(self):
self.reconnect_handler.kernel_manager = self.kernel_manager
self.reconnect_handler.set_status = MagicMock()
self.reconnect_handler.finish = MagicMock()
- self.reconnect_handler.current_user = 'alex'
+ self.reconnect_handler.current_user = "alex"
self.reconnect_handler.request = self.request
self.reconnect_handler.logger = MagicMock()
super(TestSparkMagicHandler, self).setUp()
def test_msg_status(self):
- assert_equals(self.reconnect_handler._msg_status(self.good_msg), 'ok')
- assert_equals(self.reconnect_handler._msg_status(self.bad_msg), 'error')
+ assert_equals(self.reconnect_handler._msg_status(self.good_msg), "ok")
+ assert_equals(self.reconnect_handler._msg_status(self.bad_msg), "error")
def test_msg_successful(self):
assert_equals(self.reconnect_handler._msg_successful(self.good_msg), True)
@@ -89,7 +110,10 @@ def test_msg_successful(self):
def test_msg_error(self):
assert_equals(self.reconnect_handler._msg_error(self.good_msg), None)
- assert_equals(self.reconnect_handler._msg_error(self.bad_msg), u'{}:\n{}'.format('SyntaxError', 'oh no!'))
+ assert_equals(
+ self.reconnect_handler._msg_error(self.bad_msg),
+ "{}:\n{}".format("SyntaxError", "oh no!"),
+ )
@gen_test
def test_post_no_json(self):
@@ -101,7 +125,9 @@ def test_post_no_json(self):
msg = "Invalid JSON in request body."
self.reconnect_handler.set_status.assert_called_once_with(400)
self.reconnect_handler.finish.assert_called_once_with(msg)
- self.spark_events.emit_cluster_change_event.assert_called_once_with(None, 400, False, msg)
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ None, 400, False, msg
+ )
@gen_test
def test_post_no_key(self):
@@ -110,15 +136,24 @@ def test_post_no_key(self):
res = yield self.reconnect_handler.post()
assert_equals(res, None)
- msg = 'HTTP 400: Bad Request (Missing argument path)'
+ msg = "HTTP 400: Bad Request (Missing argument path)"
self.reconnect_handler.set_status.assert_called_once_with(400)
self.reconnect_handler.finish.assert_called_once_with(msg)
- self.spark_events.emit_cluster_change_event.assert_called_once_with(None, 400, False, msg)
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ None, 400, False, msg
+ )
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager')
+ @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager")
@gen_test
def test_post_existing_kernel_with_auth_missing_no_auth(self, _get_kernel_manager):
- self.request.body = json.dumps({ "path": self.path, "username": '', "password": '', "endpoint": self.endpoint })
+ self.request.body = json.dumps(
+ {
+ "path": self.path,
+ "username": "",
+ "password": "",
+ "endpoint": self.endpoint,
+ }
+ )
kernel_manager_future = Future()
kernel_manager_future.set_result(self.individual_kernel_manager)
_get_kernel_manager.return_value = kernel_manager_future
@@ -126,16 +161,37 @@ def test_post_existing_kernel_with_auth_missing_no_auth(self, _get_kernel_manage
res = yield self.reconnect_handler.post()
assert_equals(res, None)
- code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, '', '', constants.NO_AUTH)
- self.client.execute.assert_called_once_with(code, silent=False, store_history=False)
+ code = "%{} -s {} -u {} -p {} -t {}".format(
+ KernelMagics._do_not_call_change_endpoint.__name__,
+ self.endpoint,
+ "",
+ "",
+ constants.NO_AUTH,
+ )
+ self.client.execute.assert_called_once_with(
+ code, silent=False, store_history=False
+ )
self.reconnect_handler.set_status.assert_called_once_with(200)
- self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}')
- self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None)
-
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager')
+ self.reconnect_handler.finish.assert_called_once_with(
+ '{"error": null, "success": true}'
+ )
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ self.endpoint, 200, True, None
+ )
+
+ @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager")
@gen_test
- def test_post_existing_kernel_with_auth_missing_basic_auth(self, _get_kernel_manager):
- self.request.body = json.dumps({ "path": self.path, "username": self.username, "password": self.password, "endpoint": self.endpoint})
+ def test_post_existing_kernel_with_auth_missing_basic_auth(
+ self, _get_kernel_manager
+ ):
+ self.request.body = json.dumps(
+ {
+ "path": self.path,
+ "username": self.username,
+ "password": self.password,
+ "endpoint": self.endpoint,
+ }
+ )
kernel_manager_future = Future()
kernel_manager_future.set_result(self.individual_kernel_manager)
_get_kernel_manager.return_value = kernel_manager_future
@@ -143,13 +199,25 @@ def test_post_existing_kernel_with_auth_missing_basic_auth(self, _get_kernel_man
res = yield self.reconnect_handler.post()
assert_equals(res, None)
- code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, constants.AUTH_BASIC)
- self.client.execute.assert_called_once_with(code, silent=False, store_history=False)
+ code = "%{} -s {} -u {} -p {} -t {}".format(
+ KernelMagics._do_not_call_change_endpoint.__name__,
+ self.endpoint,
+ self.username,
+ self.password,
+ constants.AUTH_BASIC,
+ )
+ self.client.execute.assert_called_once_with(
+ code, silent=False, store_history=False
+ )
self.reconnect_handler.set_status.assert_called_once_with(200)
- self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}')
- self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None)
-
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager')
+ self.reconnect_handler.finish.assert_called_once_with(
+ '{"error": null, "success": true}'
+ )
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ self.endpoint, 200, True, None
+ )
+
+ @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager")
@gen_test
def test_post_existing_kernel(self, _get_kernel_manager):
kernel_manager_future = Future()
@@ -159,13 +227,25 @@ def test_post_existing_kernel(self, _get_kernel_manager):
res = yield self.reconnect_handler.post()
assert_equals(res, None)
- code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, self.auth)
- self.client.execute.assert_called_once_with(code, silent=False, store_history=False)
+ code = "%{} -s {} -u {} -p {} -t {}".format(
+ KernelMagics._do_not_call_change_endpoint.__name__,
+ self.endpoint,
+ self.username,
+ self.password,
+ self.auth,
+ )
+ self.client.execute.assert_called_once_with(
+ code, silent=False, store_history=False
+ )
self.reconnect_handler.set_status.assert_called_once_with(200)
- self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}')
- self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None)
-
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager')
+ self.reconnect_handler.finish.assert_called_once_with(
+ '{"error": null, "success": true}'
+ )
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ self.endpoint, 200, True, None
+ )
+
+ @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager")
@gen_test
def test_post_existing_kernel_failed(self, _get_kernel_manager):
kernel_manager_future = Future()
@@ -176,48 +256,80 @@ def test_post_existing_kernel_failed(self, _get_kernel_manager):
res = yield self.reconnect_handler.post()
assert_equals(res, None)
- code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, self.auth)
- self.client.execute.assert_called_once_with(code, silent=False, store_history=False)
+ code = "%{} -s {} -u {} -p {} -t {}".format(
+ KernelMagics._do_not_call_change_endpoint.__name__,
+ self.endpoint,
+ self.username,
+ self.password,
+ self.auth,
+ )
+ self.client.execute.assert_called_once_with(
+ code, silent=False, store_history=False
+ )
self.reconnect_handler.set_status.assert_called_once_with(500)
- self.reconnect_handler.finish.assert_called_once_with('{"error": "SyntaxError:\\noh no!", "success": false}')
- self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 500, False, "SyntaxError:\noh no!")
-
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session')
+ self.reconnect_handler.finish.assert_called_once_with(
+ '{"error": "SyntaxError:\\noh no!", "success": false}'
+ )
+ self.spark_events.emit_cluster_change_event.assert_called_once_with(
+ self.endpoint, 500, False, "SyntaxError:\noh no!"
+ )
+
+ @patch(
+ "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session"
+ )
@gen_test
- def test_get_kernel_manager_no_existing_kernel(self, _get_kernel_manager_new_session):
+ def test_get_kernel_manager_no_existing_kernel(
+ self, _get_kernel_manager_new_session
+ ):
different_path = "different_path.ipynb"
km_future = Future()
km_future.set_result(self.individual_kernel_manager)
_get_kernel_manager_new_session.return_value = km_future
-
- km = yield self.reconnect_handler._get_kernel_manager(different_path, self.kernel_name)
+
+ km = yield self.reconnect_handler._get_kernel_manager(
+ different_path, self.kernel_name
+ )
assert_equals(self.individual_kernel_manager, km)
self.individual_kernel_manager.restart_kernel.assert_not_called()
self.kernel_manager.get_kernel.assert_not_called()
- _get_kernel_manager_new_session.assert_called_once_with(different_path, self.kernel_name)
+ _get_kernel_manager_new_session.assert_called_once_with(
+ different_path, self.kernel_name
+ )
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session')
+ @patch(
+ "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session"
+ )
@gen_test
def test_get_kernel_manager_existing_kernel(self, _get_kernel_manager_new_session):
- km = yield self.reconnect_handler._get_kernel_manager(self.path, self.kernel_name)
+ km = yield self.reconnect_handler._get_kernel_manager(
+ self.path, self.kernel_name
+ )
assert_equals(self.individual_kernel_manager, km)
self.individual_kernel_manager.restart_kernel.assert_called_once_with()
_get_kernel_manager_new_session.assert_not_called()
- @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session')
+ @patch(
+ "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session"
+ )
@gen_test
- def test_get_kernel_manager_different_kernel_type(self, _get_kernel_manager_new_session):
+ def test_get_kernel_manager_different_kernel_type(
+ self, _get_kernel_manager_new_session
+ ):
different_kernel = "sparkkernel"
km_future = Future()
km_future.set_result(self.individual_kernel_manager)
_get_kernel_manager_new_session.return_value = km_future
- km = yield self.reconnect_handler._get_kernel_manager(self.path, different_kernel)
+ km = yield self.reconnect_handler._get_kernel_manager(
+ self.path, different_kernel
+ )
assert_equals(self.individual_kernel_manager, km)
self.individual_kernel_manager.restart_kernel.assert_not_called()
self.kernel_manager.get_kernel.assert_not_called()
- _get_kernel_manager_new_session.assert_called_once_with(self.path, different_kernel)
+ _get_kernel_manager_new_session.assert_called_once_with(
+ self.path, different_kernel
+ )
self.session_manager.delete_session.assert_called_once_with(self.session_id)
diff --git a/sparkmagic/sparkmagic/tests/test_heartbeatthread.py b/sparkmagic/sparkmagic/tests/test_heartbeatthread.py
index 4bdfc6bc5..df970f080 100644
--- a/sparkmagic/sparkmagic/tests/test_heartbeatthread.py
+++ b/sparkmagic/sparkmagic/tests/test_heartbeatthread.py
@@ -10,40 +10,40 @@ def test_create_thread():
refresh_seconds = 1
retry_seconds = 2
heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds)
-
+
assert_equals(heartbeat_thread.livy_session, session)
assert_equals(heartbeat_thread.refresh_seconds, refresh_seconds)
assert_equals(heartbeat_thread.retry_seconds, retry_seconds)
-
-
+
+
def test_run_once():
session = MagicMock()
refresh_seconds = 0.1
retry_seconds = 2
heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds, 1)
-
+
heartbeat_thread.start()
sleep(0.15)
heartbeat_thread.stop()
-
+
session.refresh_status_and_info.assert_called_once_with()
assert heartbeat_thread.livy_session is None
-
-
+
+
def test_run_stops():
session = MagicMock()
refresh_seconds = 0.01
retry_seconds = 2
heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds)
-
+
heartbeat_thread.start()
sleep(0.1)
heartbeat_thread.stop()
-
+
assert session.refresh_status_and_info.called
assert heartbeat_thread.livy_session is None
-
-
+
+
def test_run_retries():
msg = "oh noes!"
session = MagicMock()
@@ -51,16 +51,16 @@ def test_run_retries():
refresh_seconds = 0.1
retry_seconds = 0.1
heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds, 1)
-
+
heartbeat_thread.start()
sleep(0.15)
heartbeat_thread.stop()
-
+
session.refresh_status_and_info.assert_called_once_with()
session.logger.error.assert_called_once_with(msg)
assert heartbeat_thread.livy_session is None
-
-
+
+
def test_run_retries_stops():
msg = "oh noes!"
session = MagicMock()
@@ -68,12 +68,11 @@ def test_run_retries_stops():
refresh_seconds = 0.01
retry_seconds = 0.01
heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds)
-
+
heartbeat_thread.start()
sleep(0.1)
heartbeat_thread.stop()
-
+
assert session.refresh_status_and_info.called
assert session.logger.error.called
assert heartbeat_thread.livy_session is None
-
\ No newline at end of file
diff --git a/sparkmagic/sparkmagic/tests/test_kernel_magics.py b/sparkmagic/sparkmagic/tests/test_kernel_magics.py
index 98cd208e7..5a15bc910 100644
--- a/sparkmagic/sparkmagic/tests/test_kernel_magics.py
+++ b/sparkmagic/sparkmagic/tests/test_kernel_magics.py
@@ -7,9 +7,16 @@
import sparkmagic.utils.constants as constants
from sparkmagic.kernels.kernelmagics import KernelMagics, Namespace
from sparkmagic.magics.remotesparkmagics import RemoteSparkMagics
-from sparkmagic.livyclientlib.exceptions import LivyClientTimeoutException, BadUserDataException,\
- LivyUnexpectedStatusException, SessionManagementException,\
- HttpClientException, DataFrameParseException, SqlContextNotFoundException, SparkStatementException
+from sparkmagic.livyclientlib.exceptions import (
+ LivyClientTimeoutException,
+ BadUserDataException,
+ LivyUnexpectedStatusException,
+ SessionManagementException,
+ HttpClientException,
+ DataFrameParseException,
+ SqlContextNotFoundException,
+ SparkStatementException,
+)
from sparkmagic.livyclientlib.endpoint import Endpoint
from sparkmagic.livyclientlib.command import Command
from sparkmagic.auth.basic import Basic
@@ -21,6 +28,7 @@
ipython_display = MagicMock()
spark_events = None
+
@magics_class
class TestKernelMagics(KernelMagics):
def __init__(self, shell, data=None, spark_events=None):
@@ -43,7 +51,7 @@ def _setup():
magic.shell = shell = MagicMock()
magic.ipython_display = ipython_display = MagicMock()
magic.spark_controller = spark_controller = MagicMock()
- magic._generate_uuid = MagicMock(return_value='0000')
+ magic._generate_uuid = MagicMock(return_value="0000")
def _teardown():
@@ -65,8 +73,12 @@ def test_start_session():
assert ret
assert magic.session_started
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
# Call a second time
ret = magic._do_not_call_start_session(line)
@@ -95,12 +107,11 @@ def test_start_session_times_out():
assert not magic.session_started
assert_equals(ipython_display.send_error.call_count, 1)
+
@with_setup(_setup, _teardown)
@raises(LivyClientTimeoutException)
def test_start_session_times_out_all_errors_are_fatal():
- conf.override_all({
- "all_errors_are_fatal": True
- })
+ conf.override_all({"all_errors_are_fatal": True})
line = ""
spark_controller.add_session = MagicMock(side_effect=LivyClientTimeoutException)
@@ -129,14 +140,17 @@ def test_delete_session():
def test_delete_session_expected_exception():
line = ""
magic.session_started = True
- spark_controller.delete_session_by_name.side_effect = BadUserDataException('hey')
+ spark_controller.delete_session_by_name.side_effect = BadUserDataException("hey")
magic._do_not_call_delete_session(line)
assert not magic.session_started
spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(spark_controller.delete_session_by_name.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(
+ spark_controller.delete_session_by_name.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
@@ -174,30 +188,35 @@ def test_change_language_not_valid():
assert_equals(constants.LANG_PYTHON, magic.language)
assert_equals(Endpoint("url", None), magic.endpoint)
+
@with_setup(_setup, _teardown)
def test_change_endpoint():
- s = 'server'
- u = 'user'
- p = 'password'
+ s = "server"
+ u = "user"
+ p = "password"
t = constants.AUTH_BASIC
line = "-s {} -u {} -p {} -t {}".format(s, u, p, t)
magic._do_not_call_change_endpoint(line)
- args = Namespace(auth='Basic_Access', password='password', url='server', user='user')
+ args = Namespace(
+ auth="Basic_Access", password="password", url="server", user="user"
+ )
auth_instance = initialize_auth(args)
endpoint = Endpoint(s, auth_instance)
assert_equals(endpoint.url, magic.endpoint.url)
assert_equals(Endpoint(s, auth_instance), magic.endpoint)
+
@with_setup(_setup, _teardown)
@raises(BadUserDataException)
def test_change_endpoint_session_started():
- u = 'user'
- p = 'password'
- s = 'server'
+ u = "user"
+ p = "password"
+ s = "server"
line = "-s {} -u {} -p {}".format(s, u, p)
magic.session_started = True
magic._do_not_call_change_endpoint(line)
+
@with_setup(_setup, _teardown)
def test_info():
magic._print_endpoint_info = print_info_mock = MagicMock()
@@ -208,10 +227,15 @@ def test_info():
magic.info(line)
- print_info_mock.assert_called_once_with(session_info, spark_controller.get_session_id_for_client.return_value)
- spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name)
-
- _assert_magic_successful_event_emitted_once('info')
+ print_info_mock.assert_called_once_with(
+ session_info, spark_controller.get_session_id_for_client.return_value
+ )
+ spark_controller.get_session_id_for_client.assert_called_once_with(
+ magic.session_name
+ )
+
+ _assert_magic_successful_event_emitted_once("info")
@with_setup(_setup, _teardown)
def test_info_without_active_session():
@@ -224,22 +248,25 @@ def test_info_without_active_session():
print_info_mock.assert_called_once_with(session_info, None)
- _assert_magic_successful_event_emitted_once('info')
+ _assert_magic_successful_event_emitted_once("info")
+
@with_setup(_setup, _teardown)
def test_info_with_cell_content():
magic._print_endpoint_info = print_info_mock = MagicMock()
line = ""
session_info = ["1", "2"]
- spark_controller.get_all_sessions_endpoint_info = MagicMock(return_value=session_info)
+ spark_controller.get_all_sessions_endpoint_info = MagicMock(
+ return_value=session_info
+ )
error_msg = "Cell body for %%info magic must be empty; got 'howdy' instead"
- magic.info(line, cell='howdy')
+ magic.info(line, cell="howdy")
print_info_mock.assert_not_called()
assert_equals(ipython_display.send_error.call_count, 1)
spark_controller.get_session_id_for_client.assert_not_called()
- _assert_magic_failure_event_emitted_once('info', BadUserDataException(error_msg))
+ _assert_magic_failure_event_emitted_once("info", BadUserDataException(error_msg))
@with_setup(_setup, _teardown)
@@ -247,7 +274,9 @@ def test_info_with_argument():
magic._print_endpoint_info = print_info_mock = MagicMock()
line = "hey"
session_info = ["1", "2"]
- spark_controller.get_all_sessions_endpoint_info = MagicMock(return_value=session_info)
+ spark_controller.get_all_sessions_endpoint_info = MagicMock(
+ return_value=session_info
+ )
magic.info(line)
@@ -260,24 +289,38 @@ def test_info_with_argument():
def test_info_unexpected_exception():
magic._print_endpoint_info = MagicMock()
line = ""
- spark_controller.get_all_sessions_endpoint = MagicMock(side_effect=ValueError('utter failure'))
+ spark_controller.get_all_sessions_endpoint = MagicMock(
+ side_effect=ValueError("utter failure")
+ )
magic.info(line)
- _assert_magic_failure_event_emitted_once('info', spark_controller.get_all_sessions_endpoint.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(spark_controller.get_all_sessions_endpoint.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "info", spark_controller.get_all_sessions_endpoint.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(
+ spark_controller.get_all_sessions_endpoint.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
def test_info_expected_exception():
magic._print_endpoint_info = MagicMock()
line = ""
- spark_controller.get_all_sessions_endpoint = MagicMock(side_effect=SqlContextNotFoundException('utter failure'))
+ spark_controller.get_all_sessions_endpoint = MagicMock(
+ side_effect=SqlContextNotFoundException("utter failure")
+ )
magic.info(line)
- _assert_magic_failure_event_emitted_once('info', spark_controller.get_all_sessions_endpoint.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(spark_controller.get_all_sessions_endpoint.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "info", spark_controller.get_all_sessions_endpoint.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(
+ spark_controller.get_all_sessions_endpoint.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
@@ -285,7 +328,7 @@ def test_help():
magic.help("")
assert_equals(ipython_display.html.call_count, 1)
- _assert_magic_successful_event_emitted_once('help')
+ _assert_magic_successful_event_emitted_once("help")
@with_setup(_setup, _teardown)
@@ -295,7 +338,7 @@ def test_help_with_cell_content():
assert_equals(ipython_display.send_error.call_count, 1)
assert_equals(ipython_display.html.call_count, 0)
- _assert_magic_failure_event_emitted_once('help', BadUserDataException(msg))
+ _assert_magic_failure_event_emitted_once("help", BadUserDataException(msg))
@with_setup(_setup, _teardown)
@@ -313,7 +356,7 @@ def test_logs():
magic.logs(line)
ipython_display.write.assert_called_once_with("No logs yet.")
- _assert_magic_successful_event_emitted_once('logs')
+ _assert_magic_successful_event_emitted_once("logs")
ipython_display.write.reset_mock()
@@ -333,7 +376,7 @@ def test_logs_with_cell_content():
magic.logs(line, cell="BOOP")
assert_equals(ipython_display.send_error.call_count, 1)
- _assert_magic_failure_event_emitted_once('logs', BadUserDataException(msg))
+ _assert_magic_failure_event_emitted_once("logs", BadUserDataException(msg))
@with_setup(_setup, _teardown)
@@ -350,12 +393,17 @@ def test_logs_unexpected_exception():
magic.session_started = True
- spark_controller.get_logs = MagicMock(side_effect=SyntaxError('There was some sort of error'))
+ spark_controller.get_logs = MagicMock(
+ side_effect=SyntaxError("There was some sort of error")
+ )
magic.logs(line)
spark_controller.get_logs.assert_called_once_with()
- _assert_magic_failure_event_emitted_once('logs', spark_controller.get_logs.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(spark_controller.get_logs.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "logs", spark_controller.get_logs.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(spark_controller.get_logs.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -364,12 +412,17 @@ def test_logs_expected_exception():
magic.session_started = True
- spark_controller.get_logs = MagicMock(side_effect=LivyUnexpectedStatusException('There was some sort of error'))
+ spark_controller.get_logs = MagicMock(
+ side_effect=LivyUnexpectedStatusException("There was some sort of error")
+ )
magic.logs(line)
spark_controller.get_logs.assert_called_once_with()
- _assert_magic_failure_event_emitted_once('logs', spark_controller.get_logs.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(spark_controller.get_logs.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "logs", spark_controller.get_logs.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(spark_controller.get_logs.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -379,26 +432,30 @@ def test_configure():
# Session not started
conf.override_all({})
- magic.configure('', '{"extra": "yes"}')
+ magic.configure("", '{"extra": "yes"}')
assert_equals(conf.session_configs(), {"extra": "yes"})
- _assert_magic_successful_event_emitted_once('configure')
+ _assert_magic_successful_event_emitted_once("configure")
magic.info.assert_called_once_with("")
# Session started - no -f
magic.session_started = True
conf.override_all({})
- magic.configure('', "{\"extra\": \"yes\"}")
+ magic.configure("", '{"extra": "yes"}')
assert_equals(conf.session_configs(), {})
assert_equals(ipython_display.send_error.call_count, 1)
# Session started - with -f
magic.info.reset_mock()
conf.override_all({})
- magic.configure("-f", "{\"extra\": \"yes\"}")
+ magic.configure("-f", '{"extra": "yes"}')
assert_equals(conf.session_configs(), {"extra": "yes"})
spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK, "extra": "yes"})
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK, "extra": "yes"},
+ )
magic.info.assert_called_once_with("")
@@ -406,31 +463,45 @@ def test_configure():
def test_configure_unexpected_exception():
magic.info = MagicMock()
- magic._override_session_settings = MagicMock(side_effect=ValueError('help'))
- magic.configure('', '{"extra": "yes"}')
- _assert_magic_failure_event_emitted_once('configure', magic._override_session_settings.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG\
- .format(magic._override_session_settings.side_effect))
+ magic._override_session_settings = MagicMock(side_effect=ValueError("help"))
+ magic.configure("", '{"extra": "yes"}')
+ _assert_magic_failure_event_emitted_once(
+ "configure", magic._override_session_settings.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(
+ magic._override_session_settings.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
def test_configure_expected_exception():
magic.info = MagicMock()
- magic._override_session_settings = MagicMock(side_effect=BadUserDataException('help'))
- magic.configure('', '{"extra": "yes"}')
- _assert_magic_failure_event_emitted_once('configure', magic._override_session_settings.side_effect)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG\
- .format(magic._override_session_settings.side_effect))
+ magic._override_session_settings = MagicMock(
+ side_effect=BadUserDataException("help")
+ )
+ magic.configure("", '{"extra": "yes"}')
+ _assert_magic_failure_event_emitted_once(
+ "configure", magic._override_session_settings.side_effect
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(
+ magic._override_session_settings.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
def test_configure_cant_parse_object_as_json():
magic.info = MagicMock()
- magic._override_session_settings = MagicMock(side_effect=BadUserDataException('help'))
- magic.configure('', "I CAN'T PARSE THIS AS JSON")
- _assert_magic_successful_event_emitted_once('configure')
+ magic._override_session_settings = MagicMock(
+ side_effect=BadUserDataException("help")
+ )
+ magic.configure("", "I CAN'T PARSE THIS AS JSON")
+ _assert_magic_successful_event_emitted_once("configure")
assert_equals(ipython_display.send_error.call_count, 1)
@@ -443,6 +514,7 @@ def test_get_session_settings():
assert magic.get_session_settings("something -f", True) == "something"
assert magic.get_session_settings("something", True) is None
+
@with_setup(_setup, _teardown)
def test_send_to_spark_with_non_empty_cell_error():
line = "-i input -n name -t str"
@@ -455,6 +527,7 @@ def test_send_to_spark_with_non_empty_cell_error():
assert_equals(ipython_display.send_error.call_count, 1)
+
@with_setup(_setup, _teardown)
def test_send_to_spark_with_no_i_param_error():
line = "-n name -t str"
@@ -467,6 +540,7 @@ def test_send_to_spark_with_no_i_param_error():
assert_equals(ipython_display.send_error.call_count, 1)
+
@with_setup(_setup, _teardown)
def test_send_to_spark_ok():
line = "-i input -n name -t str"
@@ -477,21 +551,32 @@ def test_send_to_spark_ok():
magic.send_to_spark(line, cell)
assert ipython_display.write.called
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
spark_controller.run_command.assert_called_once_with(Command(cell), None)
+
@with_setup(_setup, _teardown)
def test_spark():
line = ""
cell = "some spark code"
- spark_controller.run_command = MagicMock(return_value=(True, line, constants.MIMETYPE_TEXT_PLAIN))
+ spark_controller.run_command = MagicMock(
+ return_value=(True, line, constants.MIMETYPE_TEXT_PLAIN)
+ )
magic.spark(line, cell)
ipython_display.write.assert_called_once_with(line)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
spark_controller.run_command.assert_called_once_with(Command(cell), None)
@@ -510,13 +595,21 @@ def test_spark_with_argument():
def test_spark_error():
line = ""
cell = "some spark code"
- spark_controller.run_command = MagicMock(return_value=(False, line, constants.MIMETYPE_TEXT_PLAIN))
+ spark_controller.run_command = MagicMock(
+ return_value=(False, line, constants.MIMETYPE_TEXT_PLAIN)
+ )
magic.spark(line, cell)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG.format(line))
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(line)
+ )
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
spark_controller.run_command.assert_called_once_with(Command(cell), None)
@@ -538,37 +631,45 @@ def test_spark_failed_session_start():
def test_spark_unexpected_exception():
line = ""
cell = "some spark code"
- spark_controller.run_command = MagicMock(side_effect=Exception('oups'))
+ spark_controller.run_command = MagicMock(side_effect=Exception("oups"))
magic.spark(line, cell)
spark_controller.run_command.assert_called_once_with(Command(cell), None)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(spark_controller.run_command.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(spark_controller.run_command.side_effect)
+ )
@with_setup(_setup, _teardown)
def test_spark_expected_exception():
line = ""
cell = "some spark code"
- spark_controller.run_command = MagicMock(side_effect=SessionManagementException('oups'))
+ spark_controller.run_command = MagicMock(
+ side_effect=SessionManagementException("oups")
+ )
magic.spark(line, cell)
spark_controller.run_command.assert_called_once_with(Command(cell), None)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(spark_controller.run_command.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(spark_controller.run_command.side_effect)
+ )
@with_setup(_setup, _teardown)
@raises(SparkStatementException)
def test_spark_fatal_spark_statement_exception():
- conf.override_all({
- "all_errors_are_fatal": True,
- })
+ conf.override_all(
+ {
+ "all_errors_are_fatal": True,
+ }
+ )
line = ""
cell = "some spark code"
- spark_controller.run_command = MagicMock(side_effect=SparkStatementException('Oh no!'))
+ spark_controller.run_command = MagicMock(
+ side_effect=SparkStatementException("Oh no!")
+ )
magic.spark(line, cell)
@@ -577,28 +678,33 @@ def test_spark_fatal_spark_statement_exception():
def test_spark_unexpected_exception_in_storing():
line = "-o var_name"
cell = "some spark code"
- side_effect = [(True, 'ok', constants.MIMETYPE_TEXT_PLAIN), Exception('oups')]
+ side_effect = [(True, "ok", constants.MIMETYPE_TEXT_PLAIN), Exception("oups")]
spark_controller.run_command = MagicMock(side_effect=side_effect)
magic.spark(line, cell)
assert_equals(spark_controller.run_command.call_count, 2)
spark_controller.run_command.assert_any_call(Command(cell), None)
- ipython_display.send_error.assert_called_with(constants.INTERNAL_ERROR_MSG
- .format(side_effect[1]))
+ ipython_display.send_error.assert_called_with(
+ constants.INTERNAL_ERROR_MSG.format(side_effect[1])
+ )
@with_setup(_setup, _teardown)
def test_spark_expected_exception_in_storing():
line = "-o var_name"
cell = "some spark code"
- side_effect = [(True, 'ok', constants.MIMETYPE_TEXT_PLAIN), SessionManagementException('oups')]
+ side_effect = [
+ (True, "ok", constants.MIMETYPE_TEXT_PLAIN),
+ SessionManagementException("oups"),
+ ]
spark_controller.run_command = MagicMock(side_effect=side_effect)
magic.spark(line, cell)
assert spark_controller.run_command.call_count == 2
spark_controller.run_command.assert_any_call(Command(cell), None)
- ipython_display.send_error.assert_called_with(constants.EXPECTED_ERROR_MSG
- .format(side_effect[1]))
+ ipython_display.send_error.assert_called_with(
+ constants.EXPECTED_ERROR_MSG.format(side_effect[1])
+ )
@with_setup(_setup, _teardown)
@@ -608,7 +714,9 @@ def test_spark_sample_options():
magic.execute_spark = MagicMock()
ret = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, True)
+ magic.execute_spark.assert_called_once_with(
+ cell, "var_name", "sample", 142, 0.3, None, True
+ )
@with_setup(_setup, _teardown)
@@ -618,7 +726,9 @@ def test_spark_false_coerce():
magic.execute_spark = MagicMock()
ret = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, False)
+ magic.execute_spark.assert_called_once_with(
+ cell, "var_name", "sample", 142, 0.3, None, False
+ )
@with_setup(_setup, _teardown)
@@ -629,9 +739,15 @@ def test_sql_without_output():
magic.sql(line, cell)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
- magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, None, False, None)
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, None, None, None, None, None, False, None
+ )
@with_setup(_setup, _teardown)
@@ -642,33 +758,45 @@ def test_sql_with_output():
magic.sql(line, cell)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
- magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, None, None, None, None, "my_var", False, None
+ )
@with_setup(_setup, _teardown)
def test_sql_exception():
line = "-o my_var"
cell = "some spark code"
- magic.execute_sqlquery = MagicMock(side_effect=ValueError('HAHAHAHAH'))
+ magic.execute_sqlquery = MagicMock(side_effect=ValueError("HAHAHAHAH"))
magic.sql(line, cell)
- magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(magic.execute_sqlquery.side_effect))
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, None, None, None, None, "my_var", False, None
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(magic.execute_sqlquery.side_effect)
+ )
@with_setup(_setup, _teardown)
def test_sql_expected_exception():
line = "-o my_var"
cell = "some spark code"
- magic.execute_sqlquery = MagicMock(side_effect=HttpClientException('HAHAHAHAH'))
+ magic.execute_sqlquery = MagicMock(side_effect=HttpClientException("HAHAHAHAH"))
magic.sql(line, cell)
- magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(magic.execute_sqlquery.side_effect))
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, None, None, None, None, "my_var", False, None
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(magic.execute_sqlquery.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -692,9 +820,15 @@ def test_sql_quiet():
ret = magic.sql(line, cell)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
- magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "Output", True, None)
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, None, None, None, None, "Output", True, None
+ )
@with_setup(_setup, _teardown)
@@ -705,9 +839,15 @@ def test_sql_sample_options():
ret = magic.sql(line, cell)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
- magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, True)
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, "sample", 142, 0.3, None, None, True, True
+ )
@with_setup(_setup, _teardown)
@@ -718,9 +858,15 @@ def test_sql_false_coerce():
ret = magic.sql(line, cell)
- spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_PYSPARK})
- magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, False)
+ spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_PYSPARK},
+ )
+ magic.execute_sqlquery.assert_called_once_with(
+ cell, "sample", 142, 0.3, None, None, True, False
+ )
@with_setup(_setup, _teardown)
@@ -733,7 +879,7 @@ def test_cleanup_without_force():
magic.cleanup(line, cell)
- _assert_magic_successful_event_emitted_once('cleanup')
+ _assert_magic_successful_event_emitted_once("cleanup")
assert_equals(ipython_display.send_error.call_count, 1)
assert_equals(spark_controller.cleanup_endpoint.call_count, 0)
@@ -749,7 +895,7 @@ def test_cleanup_with_force():
magic.cleanup(line, cell)
- _assert_magic_successful_event_emitted_once('cleanup')
+ _assert_magic_successful_event_emitted_once("cleanup")
spark_controller.cleanup_endpoint.assert_called_once_with(magic.endpoint)
spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name)
@@ -766,7 +912,7 @@ def test_cleanup_with_cell_content():
magic.cleanup(line, cell)
assert_equals(ipython_display.send_error.call_count, 1)
- _assert_magic_failure_event_emitted_once('cleanup', BadUserDataException(msg))
+ _assert_magic_failure_event_emitted_once("cleanup", BadUserDataException(msg))
@with_setup(_setup, _teardown)
@@ -774,13 +920,20 @@ def test_cleanup_exception():
line = "-f"
cell = ""
magic.session_started = True
- spark_controller.cleanup_endpoint = MagicMock(side_effect=ArithmeticError('DIVISION BY ZERO OH NO'))
+ spark_controller.cleanup_endpoint = MagicMock(
+ side_effect=ArithmeticError("DIVISION BY ZERO OH NO")
+ )
magic.cleanup(line, cell)
- _assert_magic_failure_event_emitted_once('cleanup', spark_controller.cleanup_endpoint.side_effect)
+ _assert_magic_failure_event_emitted_once(
+ "cleanup", spark_controller.cleanup_endpoint.side_effect
+ )
spark_controller.cleanup_endpoint.assert_called_once_with(magic.endpoint)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(spark_controller.cleanup_endpoint.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(
+ spark_controller.cleanup_endpoint.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
@@ -793,7 +946,7 @@ def test_delete_without_force():
magic.delete(line, cell)
- _assert_magic_successful_event_emitted_once('delete')
+ _assert_magic_successful_event_emitted_once("delete")
assert_equals(ipython_display.send_error.call_count, 1)
assert_equals(spark_controller.delete_session_by_id.call_count, 0)
@@ -809,7 +962,7 @@ def test_delete_without_session_id():
magic.delete(line, cell)
- _assert_magic_successful_event_emitted_once('delete')
+ _assert_magic_successful_event_emitted_once("delete")
assert_equals(ipython_display.send_error.call_count, 1)
assert_equals(spark_controller.delete_session_by_id.call_count, 0)
@@ -825,7 +978,7 @@ def test_delete_with_force_same_session():
magic.delete(line, cell)
- _assert_magic_successful_event_emitted_once('delete')
+ _assert_magic_successful_event_emitted_once("delete")
assert_equals(ipython_display.send_error.call_count, 1)
assert_equals(spark_controller.delete_session_by_id.call_count, 0)
@@ -842,10 +995,14 @@ def test_delete_with_force_none_session():
magic.delete(line, cell)
- _assert_magic_successful_event_emitted_once('delete')
+ _assert_magic_successful_event_emitted_once("delete")
- spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name)
- spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id)
+ spark_controller.get_session_id_for_client.assert_called_once_with(
+ magic.session_name
+ )
+ spark_controller.delete_session_by_id.assert_called_once_with(
+ magic.endpoint, session_id
+ )
@with_setup(_setup, _teardown)
@@ -860,7 +1017,7 @@ def test_delete_with_cell_content():
magic.delete(line, cell)
- _assert_magic_failure_event_emitted_once('delete', BadUserDataException(msg))
+ _assert_magic_failure_event_emitted_once("delete", BadUserDataException(msg))
assert_equals(ipython_display.send_error.call_count, 1)
@@ -875,10 +1032,14 @@ def test_delete_with_force_different_session():
magic.delete(line, cell)
- _assert_magic_successful_event_emitted_once('delete')
+ _assert_magic_successful_event_emitted_once("delete")
- spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name)
- spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id)
+ spark_controller.get_session_id_for_client.assert_called_once_with(
+ magic.session_name
+ )
+ spark_controller.delete_session_by_id.assert_called_once_with(
+ magic.endpoint, session_id
+ )
@with_setup(_setup, _teardown)
@@ -887,15 +1048,26 @@ def test_delete_exception():
session_id = 0
line = "-f -s {}".format(session_id)
cell = ""
- spark_controller.delete_session_by_id = MagicMock(side_effect=DataFrameParseException('wow'))
+ spark_controller.delete_session_by_id = MagicMock(
+ side_effect=DataFrameParseException("wow")
+ )
spark_controller.get_session_id_for_client = MagicMock()
magic.delete(line, cell)
- _assert_magic_failure_event_emitted_once('delete', spark_controller.delete_session_by_id.side_effect)
- spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name)
- spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id)
- ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
- .format(spark_controller.delete_session_by_id.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "delete", spark_controller.delete_session_by_id.side_effect
+ )
+ spark_controller.get_session_id_for_client.assert_called_once_with(
+ magic.session_name
+ )
+ spark_controller.delete_session_by_id.assert_called_once_with(
+ magic.endpoint, session_id
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.INTERNAL_ERROR_MSG.format(
+ spark_controller.delete_session_by_id.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
@@ -904,15 +1076,26 @@ def test_delete_expected_exception():
session_id = 0
line = "-f -s {}".format(session_id)
cell = ""
- spark_controller.delete_session_by_id = MagicMock(side_effect=LivyClientTimeoutException('wow'))
+ spark_controller.delete_session_by_id = MagicMock(
+ side_effect=LivyClientTimeoutException("wow")
+ )
spark_controller.get_session_id_for_client = MagicMock()
magic.delete(line, cell)
- _assert_magic_failure_event_emitted_once('delete', spark_controller.delete_session_by_id.side_effect)
- spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name)
- spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id)
- ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
- .format(spark_controller.delete_session_by_id.side_effect))
+ _assert_magic_failure_event_emitted_once(
+ "delete", spark_controller.delete_session_by_id.side_effect
+ )
+ spark_controller.get_session_id_for_client.assert_called_once_with(
+ magic.session_name
+ )
+ spark_controller.delete_session_by_id.assert_called_once_with(
+ magic.endpoint, session_id
+ )
+ ipython_display.send_error.assert_called_once_with(
+ constants.EXPECTED_ERROR_MSG.format(
+ spark_controller.delete_session_by_id.side_effect
+ )
+ )
@with_setup(_setup, _teardown)
@@ -925,8 +1108,12 @@ def test_start_session_displays_fatal_error_when_session_throws():
magic._do_not_call_start_session("")
- magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_SPARK})
+ magic.spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_SPARK},
+ )
assert magic.fatal_error
assert magic.fatal_error_message == conf.fatal_error_suggestion().format(str(e))
@@ -945,8 +1132,12 @@ def test_start_session_when_retry_fatal_error_is_not_allowed_by_default():
magic._do_not_call_start_session("")
# call add_session once and call send_error twice to show the error message
- magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_SPARK})
+ magic.spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_SPARK},
+ )
assert_equals(magic.ipython_display.send_error.call_count, 2)
@@ -961,8 +1152,12 @@ def test_retry_start_session_when_retry_fatal_error_is_allowed():
# first session creation - failed
session_created = magic._do_not_call_start_session("")
- magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_SPARK})
+ magic.spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_SPARK},
+ )
assert not session_created
assert magic.fatal_error
assert magic.fatal_error_message == conf.fatal_error_suggestion().format(str(e))
@@ -970,13 +1165,17 @@ def test_retry_start_session_when_retry_fatal_error_is_allowed():
# retry session creation - successful
magic.spark_controller.add_session = MagicMock()
session_created = magic._do_not_call_start_session("")
- magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
- {"kind": constants.SESSION_KIND_SPARK})
+ magic.spark_controller.add_session.assert_called_once_with(
+ magic.session_name,
+ magic.endpoint,
+ False,
+ {"kind": constants.SESSION_KIND_SPARK},
+ )
print(session_created)
assert session_created
assert magic.session_started
assert not magic.fatal_error
- assert magic.fatal_error_message == u''
+ assert magic.fatal_error_message == ""
# show error message only once
assert magic.ipython_display.send_error.call_count == 1
@@ -993,30 +1192,42 @@ def test_allow_retry_fatal():
def test_kernel_magics_names():
"""The magics machinery in IPython depends on the docstrings and
method names matching up correctly"""
- assert_equals(magic.help.__name__, 'help')
- assert_equals(magic.local.__name__, 'local')
- assert_equals(magic.info.__name__, 'info')
- assert_equals(magic.logs.__name__, 'logs')
- assert_equals(magic.configure.__name__, 'configure')
- assert_equals(magic.spark.__name__, 'spark')
- assert_equals(magic.sql.__name__, 'sql')
- assert_equals(magic.cleanup.__name__, 'cleanup')
- assert_equals(magic.delete.__name__, 'delete')
+ assert_equals(magic.help.__name__, "help")
+ assert_equals(magic.local.__name__, "local")
+ assert_equals(magic.info.__name__, "info")
+ assert_equals(magic.logs.__name__, "logs")
+ assert_equals(magic.configure.__name__, "configure")
+ assert_equals(magic.spark.__name__, "spark")
+ assert_equals(magic.sql.__name__, "sql")
+ assert_equals(magic.cleanup.__name__, "cleanup")
+ assert_equals(magic.delete.__name__, "delete")
def _assert_magic_successful_event_emitted_once(name):
magic._generate_uuid.assert_called_once_with()
- spark_events.emit_magic_execution_start_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK,
- magic._generate_uuid.return_value)
- spark_events.emit_magic_execution_end_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK,
- magic._generate_uuid.return_value, True,
- '', '')
+ spark_events.emit_magic_execution_start_event.assert_called_once_with(
+ name, constants.SESSION_KIND_PYSPARK, magic._generate_uuid.return_value
+ )
+ spark_events.emit_magic_execution_end_event.assert_called_once_with(
+ name,
+ constants.SESSION_KIND_PYSPARK,
+ magic._generate_uuid.return_value,
+ True,
+ "",
+ "",
+ )
def _assert_magic_failure_event_emitted_once(name, error):
magic._generate_uuid.assert_called_once_with()
- spark_events.emit_magic_execution_start_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK,
- magic._generate_uuid.return_value)
- spark_events.emit_magic_execution_end_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK,
- magic._generate_uuid.return_value, False,
- error.__class__.__name__, str(error))
+ spark_events.emit_magic_execution_start_event.assert_called_once_with(
+ name, constants.SESSION_KIND_PYSPARK, magic._generate_uuid.return_value
+ )
+ spark_events.emit_magic_execution_end_event.assert_called_once_with(
+ name,
+ constants.SESSION_KIND_PYSPARK,
+ magic._generate_uuid.return_value,
+ False,
+ error.__class__.__name__,
+ str(error),
+ )
diff --git a/sparkmagic/sparkmagic/tests/test_kernels.py b/sparkmagic/sparkmagic/tests/test_kernels.py
index 0a2c7a0a2..6867659ab 100644
--- a/sparkmagic/sparkmagic/tests/test_kernels.py
+++ b/sparkmagic/sparkmagic/tests/test_kernels.py
@@ -26,18 +26,15 @@ def test_pyspark_kernel_configs():
kernel = TestPyparkKernel()
assert kernel.session_language == LANG_PYTHON
- assert kernel.implementation == 'PySpark'
+ assert kernel.implementation == "PySpark"
assert kernel.language == LANG_PYTHON
- assert kernel.language_version == '0.1'
+ assert kernel.language_version == "0.1"
assert kernel.language_info == {
- 'name': 'pyspark',
- 'mimetype': 'text/x-python',
- 'codemirror_mode': {
- 'name': 'python',
- 'version': 3
- },
- 'file_extension': '.py',
- 'pygments_lexer': 'python3',
+ "name": "pyspark",
+ "mimetype": "text/x-python",
+ "codemirror_mode": {"name": "python", "version": 3},
+ "file_extension": ".py",
+ "pygments_lexer": "python3",
}
@@ -46,15 +43,15 @@ def test_spark_kernel_configs():
assert kernel.session_language == LANG_SCALA
- assert kernel.implementation == 'Spark'
+ assert kernel.implementation == "Spark"
assert kernel.language == LANG_SCALA
- assert kernel.language_version == '0.1'
+ assert kernel.language_version == "0.1"
assert kernel.language_info == {
- 'name': 'scala',
- 'mimetype': 'text/x-scala',
- 'pygments_lexer': 'scala',
- 'file_extension': '.sc',
- 'codemirror_mode': 'text/x-scala',
+ "name": "scala",
+ "mimetype": "text/x-scala",
+ "pygments_lexer": "scala",
+ "file_extension": ".sc",
+ "codemirror_mode": "text/x-scala",
}
@@ -63,13 +60,13 @@ def test_sparkr_kernel_configs():
assert kernel.session_language == LANG_R
- assert kernel.implementation == 'SparkR'
+ assert kernel.implementation == "SparkR"
assert kernel.language == LANG_R
- assert kernel.language_version == '0.1'
+ assert kernel.language_version == "0.1"
assert kernel.language_info == {
- 'name': 'sparkR',
- 'mimetype': 'text/x-rsrc',
- 'pygments_lexer': 'r',
- 'file_extension': '.r',
- 'codemirror_mode': 'text/x-rsrc'
+ "name": "sparkR",
+ "mimetype": "text/x-rsrc",
+ "pygments_lexer": "r",
+ "file_extension": ".r",
+ "codemirror_mode": "text/x-rsrc",
}
diff --git a/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py b/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py
index 9d4c77091..4f0beb87a 100644
--- a/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py
+++ b/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py
@@ -13,7 +13,7 @@
def test_post_statement():
http_client = MagicMock()
livy_client = LivyReliableHttpClient(http_client, None)
- data = {"adlfj":"sadflkjsdf"}
+ data = {"adlfj": "sadflkjsdf"}
out = livy_client.post_statement(100, data)
assert_equals(out, http_client.post.return_value.json.return_value)
http_client.post.assert_called_once_with("/sessions/100/statements", [201], data)
@@ -32,7 +32,9 @@ def test_cancel_statement():
livy_client = LivyReliableHttpClient(http_client, None)
out = livy_client.cancel_statement(100, 104)
assert_equals(out, http_client.post.return_value.json.return_value)
- http_client.post.assert_called_once_with("/sessions/100/statements/104/cancel", [200], {})
+ http_client.post.assert_called_once_with(
+ "/sessions/100/statements/104/cancel", [200], {}
+ )
def test_get_sessions():
@@ -46,7 +48,7 @@ def test_get_sessions():
def test_post_session():
http_client = MagicMock()
livy_client = LivyReliableHttpClient(http_client, None)
- properties = {"adlfj":"sadflkjsdf", 1: [2,3,4,5]}
+ properties = {"adlfj": "sadflkjsdf", 1: [2, 3, 4, 5]}
out = livy_client.post_session(properties)
assert_equals(out, http_client.post.return_value.json.return_value)
http_client.post.assert_called_once_with("/sessions", [201], properties)
@@ -77,7 +79,7 @@ def test_get_all_session_logs():
def test_custom_headers():
custom_headers = {"header1": "value1"}
- overrides = { conf.custom_headers.__name__: custom_headers }
+ overrides = {conf.custom_headers.__name__: custom_headers}
conf.override_all(overrides)
endpoint = Endpoint("http://url.com", None)
client = LivyReliableHttpClient.from_endpoint(endpoint)
@@ -113,5 +115,5 @@ def test_retry_policy():
def _override_policy(policy):
- overrides = { conf.retry_policy.__name__: policy }
+ overrides = {conf.retry_policy.__name__: policy}
conf.override_all(overrides)
diff --git a/sparkmagic/sparkmagic/tests/test_livysession.py b/sparkmagic/sparkmagic/tests/test_livysession.py
index 111362bc8..2407daf2c 100644
--- a/sparkmagic/sparkmagic/tests/test_livysession.py
+++ b/sparkmagic/sparkmagic/tests/test_livysession.py
@@ -4,8 +4,12 @@
import sparkmagic.utils.constants as constants
import sparkmagic.utils.configuration as conf
-from sparkmagic.livyclientlib.exceptions import LivyClientTimeoutException, LivyUnexpectedStatusException,\
- BadUserDataException, SqlContextNotFoundException
+from sparkmagic.livyclientlib.exceptions import (
+ LivyClientTimeoutException,
+ LivyUnexpectedStatusException,
+ BadUserDataException,
+ SqlContextNotFoundException,
+)
from sparkmagic.livyclientlib.livysession import LivySession
@@ -27,27 +31,43 @@ class TestLivySession(object):
pi_result = "Pi is roughly 3.14336"
- session_create_json = json.loads('{"id":0,"state":"starting","kind":"spark","log":[]}')
- resource_limit_json = json.loads('{"id":0,"state":"starting","kind":"spark","log":['
- '"Queue\'s AM resource limit exceeded."]}')
- ready_sessions_json = json.loads('{"id":0,"state":"idle","kind":"spark","log":[""]}')
- recovering_sessions_json = json.loads('{"id":0,"state":"recovering","kind":"spark","log":[""]}')
- error_sessions_json = json.loads('{"id":0,"state":"error","kind":"spark","log":[""]}')
+ session_create_json = json.loads(
+ '{"id":0,"state":"starting","kind":"spark","log":[]}'
+ )
+ resource_limit_json = json.loads(
+ '{"id":0,"state":"starting","kind":"spark","log":['
+ '"Queue\'s AM resource limit exceeded."]}'
+ )
+ ready_sessions_json = json.loads(
+ '{"id":0,"state":"idle","kind":"spark","log":[""]}'
+ )
+ recovering_sessions_json = json.loads(
+ '{"id":0,"state":"recovering","kind":"spark","log":[""]}'
+ )
+ error_sessions_json = json.loads(
+ '{"id":0,"state":"error","kind":"spark","log":[""]}'
+ )
busy_sessions_json = json.loads('{"id":0,"state":"busy","kind":"spark","log":[""]}')
post_statement_json = json.loads('{"id":0,"state":"running","output":null}')
waiting_statement_json = json.loads('{"id":0,"state":"waiting","output":null}')
running_statement_json = json.loads('{"id":0,"state":"running","output":null}')
- ready_statement_json = json.loads('{"id":0,"state":"available","output":{"status":"ok",'
- '"execution_count":0,"data":{"text/plain":"Pi is roughly 3.14336"}}}')
- ready_statement_null_output_json = json.loads('{"id":0,"state":"available","output":null}')
- ready_statement_failed_json = json.loads('{"id":0,"state":"available","output":{"status":"error",'
- '"evalue":"error","traceback":"error"}}')
+ ready_statement_json = json.loads(
+ '{"id":0,"state":"available","output":{"status":"ok",'
+ '"execution_count":0,"data":{"text/plain":"Pi is roughly 3.14336"}}}'
+ )
+ ready_statement_null_output_json = json.loads(
+ '{"id":0,"state":"available","output":null}'
+ )
+ ready_statement_failed_json = json.loads(
+ '{"id":0,"state":"available","output":{"status":"error",'
+ '"evalue":"error","traceback":"error"}}'
+ )
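+ # Canned Livy REST payloads: the tests feed these to a mocked http_client so
+ # no real Livy server is involved.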
log_json = json.loads('{"id":6,"from":0,"total":212,"log":["hi","hi"]}')
def __init__(self):
self.http_client = None
self.spark_events = None
-
+
self.get_statement_responses = []
self.post_statement_responses = []
self.get_session_responses = []
@@ -78,16 +98,19 @@ def _next_session_response_post(self, *args):
self.post_session_responses = self.post_session_responses[1:]
return val
- def _create_session(self, kind=constants.SESSION_KIND_SPARK, session_id=-1,
- heartbeat_timeout=60):
+ def _create_session(
+ self, kind=constants.SESSION_KIND_SPARK, session_id=-1, heartbeat_timeout=60
+ ):
ipython_display = MagicMock()
- session = LivySession(self.http_client,
- {"kind": kind},
- ipython_display,
- session_id,
- self.spark_events,
- heartbeat_timeout,
- self.heartbeat_thread)
+ session = LivySession(
+ self.http_client,
+ {"kind": kind},
+ ipython_display,
+ session_id,
+ self.spark_events,
+ heartbeat_timeout,
+ self.heartbeat_thread,
+ )
return session
def _create_session_with_fixed_get_response(self, get_session_json):
@@ -109,22 +132,22 @@ def test_constructor_starts_with_existing_session(self):
assert session.id == session_id
assert session._heartbeat_thread is None
- assert constants.LIVY_HEARTBEAT_TIMEOUT_PARAM not in list(session.properties.keys())
-
+ assert constants.LIVY_HEARTBEAT_TIMEOUT_PARAM not in list(
+ session.properties.keys()
+ )
+
def test_constructor_starts_heartbeat_with_existing_session(self):
- conf.override_all({
- "heartbeat_refresh_seconds": 0.1
- })
+ conf.override_all({"heartbeat_refresh_seconds": 0.1})
session_id = 1
session = self._create_session(session_id=session_id)
conf.override_all({})
-
+
assert session.id == session_id
assert self.heartbeat_thread.daemon
self.heartbeat_thread.start.assert_called_once_with()
assert not session._heartbeat_thread is None
- assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM ] > 0
-
+ assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM] > 0
+
def test_start_with_heartbeat(self):
self.http_client.post_session.return_value = self.session_create_json
self.http_client.get_session.return_value = self.ready_sessions_json
@@ -132,12 +155,12 @@ def test_start_with_heartbeat(self):
session = self._create_session()
session.start()
-
+
assert self.heartbeat_thread.daemon
self.heartbeat_thread.start.assert_called_once_with()
assert not session._heartbeat_thread is None
- assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM ] > 0
-
+ assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM] > 0
+
def test_start_with_heartbeat_calls_only_once(self):
self.http_client.post_session.return_value = self.session_create_json
self.http_client.get_session.return_value = self.ready_sessions_json
@@ -151,7 +174,7 @@ def test_start_with_heartbeat_calls_only_once(self):
assert self.heartbeat_thread.daemon
self.heartbeat_thread.start.assert_called_once_with()
assert not session._heartbeat_thread is None
-
+
def test_delete_with_heartbeat(self):
self.http_client.post_session.return_value = self.session_create_json
self.http_client.get_session.return_value = self.ready_sessions_json
@@ -160,9 +183,9 @@ def test_delete_with_heartbeat(self):
session = self._create_session()
session.start()
heartbeat_thread = session._heartbeat_thread
-
+
session.delete()
-
+
self.heartbeat_thread.stop.assert_called_once_with()
assert session._heartbeat_thread is None
@@ -195,7 +218,9 @@ def test_start_scala_starts_session(self):
assert_equals(kind, session.kind)
assert_equals("idle", session.status)
assert_equals(0, session.id)
- self.http_client.post_session.assert_called_with({"kind": "spark", "heartbeatTimeoutInSecond": 60})
+ self.http_client.post_session.assert_called_with(
+ {"kind": "spark", "heartbeatTimeoutInSecond": 60}
+ )
def test_start_r_starts_session(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -209,7 +234,9 @@ def test_start_r_starts_session(self):
assert_equals(kind, session.kind)
assert_equals("idle", session.status)
assert_equals(0, session.id)
- self.http_client.post_session.assert_called_with({"kind": "sparkr", "heartbeatTimeoutInSecond": 60})
+ self.http_client.post_session.assert_called_with(
+ {"kind": "sparkr", "heartbeatTimeoutInSecond": 60}
+ )
def test_start_python_starts_session(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -223,7 +250,9 @@ def test_start_python_starts_session(self):
assert_equals(kind, session.kind)
assert_equals("idle", session.status)
assert_equals(0, session.id)
- self.http_client.post_session.assert_called_with({"kind": "pyspark", "heartbeatTimeoutInSecond": 60})
+ self.http_client.post_session.assert_called_with(
+ {"kind": "pyspark", "heartbeatTimeoutInSecond": 60}
+ )
def test_start_passes_in_all_properties(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -257,12 +286,14 @@ def test_status_recovering(self):
Ensure 'recovering' state is supported: we go from recovering to idle.
"""
self.http_client.post_session.return_value = self.session_create_json
+
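+ # The mutable default argument `calls` keeps state across invocations: the
+ # first poll reports "recovering", every later poll reports "idle".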
def get_session(i, calls=[]):
if not calls:
calls.append(1)
return self.recovering_sessions_json
else:
return self.ready_sessions_json
+
self.http_client.get_session.side_effect = get_session
self.http_client.get_statement.return_value = self.ready_statement_json
session = self._create_session()
@@ -284,16 +315,18 @@ def test_logs_gets_latest_logs(self):
def test_wait_for_idle_returns_when_in_state(self):
self.http_client.post_session.return_value = self.session_create_json
- self.get_session_responses = [self.ready_sessions_json,
- self.ready_sessions_json,
- self.busy_sessions_json,
- self.ready_sessions_json]
+ self.get_session_responses = [
+ self.ready_sessions_json,
+ self.ready_sessions_json,
+ self.busy_sessions_json,
+ self.ready_sessions_json,
+ ]
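+ # Each get_session call pops the next canned response, so wait_for_idle
+ # observes a busy poll before the session settles to idle.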
self.http_client.get_session.side_effect = self._next_session_response_get
self.http_client.get_statement.return_value = self.ready_statement_json
session = self._create_session()
session.get_row_html = MagicMock()
- session.get_row_html.return_value = u"""<tr><td>row1</td></tr>"""
+ session.get_row_html.return_value = """<tr><td>row1</td></tr>"""
session.start()
@@ -304,17 +337,19 @@ def test_wait_for_idle_returns_when_in_state(self):
def test_wait_for_idle_prints_resource_limit_message(self):
self.http_client.post_session.return_value = self.session_create_json
- self.get_session_responses = [self.resource_limit_json,
- self.ready_sessions_json,
- self.ready_sessions_json,
- self.ready_sessions_json]
+ self.get_session_responses = [
+ self.resource_limit_json,
+ self.ready_sessions_json,
+ self.ready_sessions_json,
+ self.ready_sessions_json,
+ ]
self.http_client.get_session.side_effect = self._next_session_response_get
self.http_client.get_statement.return_value = self.ready_statement_json
self.http_client.get_all_session_logs.return_value = self.log_json
session = self._create_session()
session.get_row_html = MagicMock()
- session.get_row_html.return_value = u"""<tr><td>row1</td></tr>"""
+ session.get_row_html.return_value = """<tr><td>row1</td></tr>"""
session.start()
@@ -324,16 +359,18 @@ def test_wait_for_idle_prints_resource_limit_message(self):
@raises(LivyUnexpectedStatusException)
def test_wait_for_idle_throws_when_in_final_status(self):
self.http_client.post_session.return_value = self.session_create_json
- self.get_session_responses = [self.ready_sessions_json,
- self.busy_sessions_json,
- self.busy_sessions_json,
- self.error_sessions_json]
+ self.get_session_responses = [
+ self.ready_sessions_json,
+ self.busy_sessions_json,
+ self.busy_sessions_json,
+ self.error_sessions_json,
+ ]
self.http_client.get_session.side_effect = self._next_session_response_get
self.http_client.get_all_session_logs.return_value = self.log_json
session = self._create_session()
session.get_row_html = MagicMock()
- session.get_row_html.return_value = u"""<tr><td>row1</td></tr>"""
+ session.get_row_html.return_value = """<tr><td>row1</td></tr>"""
session.start()
@@ -342,17 +379,19 @@ def test_wait_for_idle_throws_when_in_final_status(self):
@raises(LivyClientTimeoutException)
def test_wait_for_idle_times_out(self):
self.http_client.post_session.return_value = self.session_create_json
- self.get_session_responses = [self.ready_sessions_json,
- self.ready_sessions_json,
- self.busy_sessions_json,
- self.busy_sessions_json,
- self.ready_sessions_json]
+ self.get_session_responses = [
+ self.ready_sessions_json,
+ self.ready_sessions_json,
+ self.busy_sessions_json,
+ self.busy_sessions_json,
+ self.ready_sessions_json,
+ ]
self.http_client.get_session.side_effect = self._next_session_response_get
self.http_client.get_statement.return_value = self.ready_statement_json
session = self._create_session()
session.get_row_html = MagicMock()
- session.get_row_html.return_value = u"""<tr><td>row1</td></tr>"""
+ session.get_row_html.return_value = """<tr><td>row1</td></tr>"""
session.start()
@@ -395,9 +434,12 @@ def test_start_emits_start_end_session(self):
session = self._create_session(kind=kind)
session.start()
- self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind)
+ self.spark_events.emit_session_creation_start_event.assert_called_once_with(
+ session.guid, kind
+ )
self.spark_events.emit_session_creation_end_event.assert_called_once_with(
- session.guid, kind, session.id, session.status, True, "", "")
+ session.guid, kind, session.id, session.status, True, "", ""
+ )
def test_start_emits_start_end_failed_session_when_bad_status(self):
self.http_client.post_session.side_effect = ValueError
@@ -412,9 +454,12 @@ def test_start_emits_start_end_failed_session_when_bad_status(self):
except ValueError:
pass
- self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind)
+ self.spark_events.emit_session_creation_start_event.assert_called_once_with(
+ session.guid, kind
+ )
self.spark_events.emit_session_creation_end_event.assert_called_once_with(
- session.guid, kind, session.id, session.status, False, "ValueError", "")
+ session.guid, kind, session.id, session.status, False, "ValueError", ""
+ )
def test_start_emits_start_end_failed_session_when_wait_for_idle_throws(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -430,9 +475,12 @@ def test_start_emits_start_end_failed_session_when_wait_for_idle_throws(self):
except ValueError:
pass
- self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind)
+ self.spark_events.emit_session_creation_start_event.assert_called_once_with(
+ session.guid, kind
+ )
self.spark_events.emit_session_creation_end_event.assert_called_once_with(
- session.guid, kind, session.id, session.status, False, "ValueError", "")
+ session.guid, kind, session.id, session.status, False, "ValueError", ""
+ )
def test_delete_session_emits_start_end(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -449,9 +497,17 @@ def test_delete_session_emits_start_end(self):
assert_equals(session.id, -1)
self.spark_events.emit_session_deletion_start_event.assert_called_once_with(
- session.guid, session.kind, end_id, end_status)
+ session.guid, session.kind, end_id, end_status
+ )
self.spark_events.emit_session_deletion_end_event.assert_called_once_with(
- session.guid, session.kind, end_id, constants.DEAD_SESSION_STATUS, True, "", "")
+ session.guid,
+ session.kind,
+ end_id,
+ constants.DEAD_SESSION_STATUS,
+ True,
+ "",
+ "",
+ )
def test_delete_session_emits_start_failed_end_when_delete_throws(self):
self.http_client.delete_session.side_effect = ValueError
@@ -472,9 +528,11 @@ def test_delete_session_emits_start_failed_end_when_delete_throws(self):
pass
self.spark_events.emit_session_deletion_start_event.assert_called_once_with(
- session.guid, session.kind, end_id, end_status)
+ session.guid, session.kind, end_id, end_status
+ )
self.spark_events.emit_session_deletion_end_event.assert_called_once_with(
- session.guid, session.kind, end_id, end_status, False, "ValueError", "")
+ session.guid, session.kind, end_id, end_status, False, "ValueError", ""
+ )
def test_delete_session_emits_start_failed_end_when_in_bad_state(self):
self.http_client.get_session.return_value = self.ready_sessions_json
@@ -491,9 +549,17 @@ def test_delete_session_emits_start_failed_end_when_in_bad_state(self):
assert_equals(0, session.ipython_display.send_error.call_count)
self.spark_events.emit_session_deletion_start_event.assert_called_once_with(
- session.guid, session.kind, end_id, end_status)
+ session.guid, session.kind, end_id, end_status
+ )
self.spark_events.emit_session_deletion_end_event.assert_called_once_with(
- session.guid, session.kind, end_id, constants.DEAD_SESSION_STATUS, True, "", "")
+ session.guid,
+ session.kind,
+ end_id,
+ constants.DEAD_SESSION_STATUS,
+ True,
+ "",
+ "",
+ )
def test_get_empty_app_id(self):
self._verify_get_app_id("null", None, 7)
@@ -502,20 +568,22 @@ def test_get_missing_app_id(self):
self._verify_get_app_id(None, None, 7)
def test_get_normal_app_id(self):
- self._verify_get_app_id("\"app_id_123\"", "app_id_123", 6)
+ self._verify_get_app_id('"app_id_123"', "app_id_123", 6)
def test_get_empty_driver_log_url(self):
self._verify_get_driver_log_url("null", None)
def test_get_normal_driver_log_url(self):
- self._verify_get_driver_log_url("\"http://example.com\"", "http://example.com")
+ self._verify_get_driver_log_url('"http://example.com"', "http://example.com")
def test_missing_app_info_get_driver_log_url(self):
self._verify_get_driver_log_url_json(self.ready_sessions_json, None)
-
+
def _verify_get_app_id(self, mock_app_id, expected_app_id, expected_call_count):
- mock_field = ",\"appId\":" + mock_app_id if mock_app_id is not None else ""
- get_session_json = json.loads('{"id":0,"state":"idle","output":null%s,"log":""}' % mock_field)
+ mock_field = ',"appId":' + mock_app_id if mock_app_id is not None else ""
+ get_session_json = json.loads(
+ '{"id":0,"state":"idle","output":null%s,"log":""}' % mock_field
+ )
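+ # mock_field splices an optional "appId" entry into the canned session JSON,
+ # letting one helper cover the present, null, and missing appId cases.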
session = self._create_session_with_fixed_get_response(get_session_json)
app_id = session.get_app_id()
@@ -524,8 +592,14 @@ def _verify_get_app_id(self, mock_app_id, expected_app_id, expected_call_count):
assert_equals(expected_call_count, self.http_client.get_session.call_count)
def _verify_get_driver_log_url(self, mock_driver_log_url, expected_url):
- mock_field = "\"driverLogUrl\":" + mock_driver_log_url if mock_driver_log_url is not None else ""
- session_json = json.loads('{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field)
+ mock_field = (
+ '"driverLogUrl":' + mock_driver_log_url
+ if mock_driver_log_url is not None
+ else ""
+ )
+ session_json = json.loads(
+ '{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field
+ )
self._verify_get_driver_log_url_json(session_json, expected_url)
def _verify_get_driver_log_url_json(self, get_session_json, expected_url):
@@ -540,14 +614,18 @@ def test_get_empty_spark_ui_url(self):
self._verify_get_spark_ui_url("null", None)
def test_get_normal_spark_ui_url(self):
- self._verify_get_spark_ui_url("\"http://example.com\"", "http://example.com")
+ self._verify_get_spark_ui_url('"http://example.com"', "http://example.com")
def test_missing_app_info_get_spark_ui_url(self):
self._verify_get_spark_ui_url_json(self.ready_sessions_json, None)
def _verify_get_spark_ui_url(self, mock_spark_ui_url, expected_url):
- mock_field = "\"sparkUiUrl\":" + mock_spark_ui_url if mock_spark_ui_url is not None else ""
- session_json = json.loads('{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field)
+ mock_field = (
+ '"sparkUiUrl":' + mock_spark_ui_url if mock_spark_ui_url is not None else ""
+ )
+ session_json = json.loads(
+ '{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field
+ )
self._verify_get_spark_ui_url_json(session_json, expected_url)
def _verify_get_spark_ui_url_json(self, get_session_json, expected_url):
@@ -565,38 +643,48 @@ def test_get_row_html(self):
session1.get_spark_ui_url = MagicMock()
session1.get_driver_log_url = MagicMock()
session1.get_user = MagicMock()
- session1.get_app_id.return_value = 'app1234'
+ session1.get_app_id.return_value = "app1234"
session1.status = constants.IDLE_SESSION_STATUS
- session1.get_spark_ui_url.return_value = 'https://microsoft.com/sparkui'
- session1.get_driver_log_url.return_value = 'https://microsoft.com/driverlog'
- session1.get_user.return_value = 'userTest'
+ session1.get_spark_ui_url.return_value = "https://microsoft.com/sparkui"
+ session1.get_driver_log_url.return_value = "https://microsoft.com/driverlog"
+ session1.get_user.return_value = "userTest"
html1 = session1.get_row_html(1)
- assert_equals(html1, u"""<tr><td>1</td><td>app1234</td><td>spark</td><td>idle</td><td><a target="_blank" href="https://microsoft.com/sparkui">Link</a></td><td><a target="_blank" href="https://microsoft.com/driverlog">Link</a></td><td>userTest</td><td>\u2714</td></tr>""")
+ assert_equals(
+ html1,
+ """1 | app1234 | spark | idle | Link | Link | userTest | \u2714 |
""",
+ )
session_id2 = 3
- session2 = self._create_session(kind=constants.SESSION_KIND_PYSPARK,
- session_id=session_id2)
+ session2 = self._create_session(
+ kind=constants.SESSION_KIND_PYSPARK, session_id=session_id2
+ )
session2.get_app_id = MagicMock()
session2.get_spark_ui_url = MagicMock()
session2.get_driver_log_url = MagicMock()
session2.get_user = MagicMock()
- session2.get_app_id.return_value = 'app5069'
+ session2.get_app_id.return_value = "app5069"
session2.status = constants.BUSY_SESSION_STATUS
session2.get_spark_ui_url.return_value = None
session2.get_driver_log_url.return_value = None
- session2.get_user.return_value = 'userTest2'
+ session2.get_user.return_value = "userTest2"
html2 = session2.get_row_html(1)
- assert_equals(html2, u"""<tr><td>3</td><td>app5069</td><td>pyspark</td><td>busy</td><td></td><td></td><td>userTest2</td><td></td></tr>""")
+ assert_equals(
+ html2,
+ """3 | app5069 | pyspark | busy | | | userTest2 | |
""",
+ )
def test_link(self):
- url = u"https://microsoft.com"
- assert_equals(LivySession.get_html_link(u'Link', url), u"""<a target="_blank" href="https://microsoft.com">Link</a>""")
+ url = "https://microsoft.com"
+ assert_equals(
+ LivySession.get_html_link("Link", url),
+ """Link""",
+ )
url = None
- assert_equals(LivySession.get_html_link(u'Link', url), u"")
+ assert_equals(LivySession.get_html_link("Link", url), "")
def test_spark_session_available(self):
self.http_client.post_session.return_value = self.session_create_json
@@ -604,13 +692,15 @@ def test_spark_session_available(self):
self.http_client.get_statement.return_value = self.ready_statement_json
session = self._create_session()
session.start()
- assert_equals(session.sql_context_variable_name,"spark")
+ assert_equals(session.sql_context_variable_name, "spark")
def test_sql_context_available(self):
self.http_client.post_session.return_value = self.session_create_json
self.http_client.get_session.return_value = self.ready_sessions_json
- self.get_statement_responses = [self.ready_statement_failed_json,
- self.ready_statement_json]
+ self.get_statement_responses = [
+ self.ready_statement_failed_json,
+ self.ready_statement_json,
+ ]
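+ # The first canned statement fails the probe for the `spark` variable, so the
+ # session should fall back to the older `sqlContext` variable.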
self.http_client.get_statement.side_effect = self._next_statement_response_get
session = self._create_session()
session.start()
diff --git a/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py b/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py
index 6a3d8750a..9b52571b6 100644
--- a/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py
+++ b/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py
@@ -5,8 +5,10 @@
def test_no_coercing():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': u'random', u'temp_diff': u'0adsf'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": "random", "temp_diff": "0adsf"},
+ ]
desired_df = pd.DataFrame(records)
df = pd.DataFrame(records)
@@ -16,8 +18,10 @@ def test_no_coercing():
def test_date_coercing():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': u'0adsf'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": "6/1/13", "temp_diff": "0adsf"},
+ ]
desired_df = pd.DataFrame(records)
desired_df["date"] = pd.to_datetime(desired_df["date"])
@@ -28,8 +32,10 @@ def test_date_coercing():
def test_date_coercing_none_values():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': None, u'temp_diff': u'0adsf'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": None, "temp_diff": "0adsf"},
+ ]
desired_df = pd.DataFrame(records)
desired_df["date"] = pd.to_datetime(desired_df["date"])
@@ -40,9 +46,11 @@ def test_date_coercing_none_values():
def test_date_none_values_and_no_coercing():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': None, u'temp_diff': u'0adsf'},
- {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0adsf'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": None, "temp_diff": "0adsf"},
+ {"buildingID": 1, "date": "adsf", "temp_diff": "0adsf"},
+ ]
desired_df = pd.DataFrame(records)
df = pd.DataFrame(records)
@@ -52,8 +60,10 @@ def test_date_none_values_and_no_coercing():
def test_numeric_coercing():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": "adsf", "temp_diff": "0"},
+ ]
desired_df = pd.DataFrame(records)
desired_df["temp_diff"] = pd.to_numeric(desired_df["temp_diff"])
@@ -64,8 +74,10 @@ def test_numeric_coercing():
def test_numeric_coercing_none_values():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': u'asdf', u'temp_diff': None}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": "asdf", "temp_diff": None},
+ ]
desired_df = pd.DataFrame(records)
desired_df["temp_diff"] = pd.to_numeric(desired_df["temp_diff"])
@@ -76,9 +88,11 @@ def test_numeric_coercing_none_values():
def test_numeric_none_values_and_no_coercing():
- records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'},
- {u'buildingID': 1, u'date': u'asdf', u'temp_diff': None},
- {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0asdf'}]
+ records = [
+ {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"},
+ {"buildingID": 1, "date": "asdf", "temp_diff": None},
+ {"buildingID": 1, "date": "adsf", "temp_diff": "0asdf"},
+ ]
desired_df = pd.DataFrame(records)
df = pd.DataFrame(records)
@@ -125,18 +139,18 @@ def test_df_dict_does_not_throw():
def test_overflow_coercing():
- records = [{'_c0':'12345678901'}]
+ records = [{"_c0": "12345678901"}]
desired_df = pd.DataFrame(records)
- desired_df['_c0'] = pd.to_numeric(desired_df['_c0'])
+ desired_df["_c0"] = pd.to_numeric(desired_df["_c0"])
df = pd.DataFrame(records)
coerce_pandas_df_to_numeric_datetime(df)
assert_frame_equal(desired_df, df)
-
+
def test_all_null_columns():
- records = [{'_c0':'12345', 'nulla': None}, {'_c0':'12345', 'nulla': None}]
+ records = [{"_c0": "12345", "nulla": None}, {"_c0": "12345", "nulla": None}]
desired_df = pd.DataFrame(records)
- desired_df['_c0'] = pd.to_numeric(desired_df['_c0'])
+ desired_df["_c0"] = pd.to_numeric(desired_df["_c0"])
df = pd.DataFrame(records)
coerce_pandas_df_to_numeric_datetime(df)
assert_frame_equal(desired_df, df)
diff --git a/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py b/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py
index 61851b41b..de0d9dcc6 100644
--- a/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py
+++ b/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py
@@ -2,7 +2,14 @@
# Distributed under the terms of the Modified BSD License.
from mock import patch, PropertyMock, MagicMock
-from nose.tools import raises, assert_equals, with_setup, assert_is_not_none, assert_false, assert_true
+from nose.tools import (
+ raises,
+ assert_equals,
+ with_setup,
+ assert_is_not_none,
+ assert_false,
+ assert_true,
+)
import requests
from requests_kerberos.kerberos_ import HTTPKerberosAuth, REQUIRED, OPTIONAL
from sparkmagic.auth.basic import Basic
@@ -60,7 +67,7 @@ def test_compose_url():
@with_setup(_setup, _teardown)
def test_get():
- with patch('requests.Session.get') as patched_get:
+ with patch("requests.Session.get") as patched_get:
type(patched_get.return_value).status_code = 200
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -73,7 +80,7 @@ def test_get():
@raises(HttpClientException)
@with_setup(_setup, _teardown)
def test_get_throws():
- with patch('requests.Session.get') as patched_get:
+ with patch("requests.Session.get") as patched_get:
type(patched_get.return_value).status_code = 500
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -88,7 +95,7 @@ def test_get_will_retry():
retry_policy.should_retry.return_value = True
retry_policy.seconds_to_sleep.return_value = 0.01
- with patch('requests.Session.get') as patched_get:
+ with patch("requests.Session.get") as patched_get:
# When we call assert_equals in this unit test, the side_effect is executed.
# So, the last status_code should be repeated.
sequential_values = [500, 200, 200]
@@ -106,7 +113,7 @@ def test_get_will_retry():
@with_setup(_setup, _teardown)
def test_post():
- with patch('requests.Session.post') as patched_post:
+ with patch("requests.Session.post") as patched_post:
type(patched_post.return_value).status_code = 200
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -119,7 +126,7 @@ def test_post():
@raises(HttpClientException)
@with_setup(_setup, _teardown)
def test_post_throws():
- with patch('requests.Session.post') as patched_post:
+ with patch("requests.Session.post") as patched_post:
type(patched_post.return_value).status_code = 500
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -134,7 +141,7 @@ def test_post_will_retry():
retry_policy.should_retry.return_value = True
retry_policy.seconds_to_sleep.return_value = 0.01
- with patch('requests.Session.post') as patched_post:
+ with patch("requests.Session.post") as patched_post:
# When we call assert_equals in this unit test, the side_effect is executed.
# So, the last status_code should be repeated.
sequential_values = [500, 200, 200]
@@ -152,7 +159,7 @@ def test_post_will_retry():
@with_setup(_setup, _teardown)
def test_delete():
- with patch('requests.Session.delete') as patched_delete:
+ with patch("requests.Session.delete") as patched_delete:
type(patched_delete.return_value).status_code = 200
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -165,7 +172,7 @@ def test_delete():
@raises(HttpClientException)
@with_setup(_setup, _teardown)
def test_delete_throws():
- with patch('requests.Session.delete') as patched_delete:
+ with patch("requests.Session.delete") as patched_delete:
type(patched_delete.return_value).status_code = 500
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -180,7 +187,7 @@ def test_delete_will_retry():
retry_policy.should_retry.return_value = True
retry_policy.seconds_to_sleep.return_value = 0.01
- with patch('requests.Session.delete') as patched_delete:
+ with patch("requests.Session.delete") as patched_delete:
# When we call assert_equals in this unit test, the side_effect is executed.
# So, the last status_code should be repeated.
sequential_values = [500, 200, 200]
@@ -203,7 +210,7 @@ def test_will_retry_error_no():
retry_policy.should_retry.return_value = False
retry_policy.seconds_to_sleep.return_value = 0.01
- with patch('requests.Session.get') as patched_get:
+ with patch("requests.Session.get") as patched_get:
patched_get.side_effect = requests.exceptions.ConnectionError()
client = ReliableHttpClient(endpoint, {}, retry_policy)
@@ -219,8 +226,8 @@ def test_basic_auth_check_auth():
endpoint = Endpoint("http://url.com", basic_auth)
client = ReliableHttpClient(endpoint, {}, retry_policy)
assert isinstance(client._auth, Basic)
- assert hasattr(client._auth, 'username')
- assert hasattr(client._auth, 'password')
+ assert hasattr(client._auth, "username")
+ assert hasattr(client._auth, "password")
assert_equals(client._auth.username, endpoint.auth.username)
assert_equals(client._auth.password, endpoint.auth.password)
@@ -238,24 +245,21 @@ def test_kerberos_auth_check_auth():
client = ReliableHttpClient(endpoint, {}, retry_policy)
assert_is_not_none(client._auth)
assert isinstance(client._auth, HTTPKerberosAuth)
- assert hasattr(client._auth, 'mutual_authentication')
+ assert hasattr(client._auth, "mutual_authentication")
assert_equals(client._auth.mutual_authentication, REQUIRED)
@with_setup(_setup, _teardown)
def test_kerberos_auth_custom_configuration():
- custom_kerberos_conf = {
- "mutual_authentication": OPTIONAL,
- "force_preemptive": True
- }
- overrides = { conf.kerberos_auth_configuration.__name__: custom_kerberos_conf }
+ custom_kerberos_conf = {"mutual_authentication": OPTIONAL, "force_preemptive": True}
+ overrides = {conf.kerberos_auth_configuration.__name__: custom_kerberos_conf}
conf.override_all(overrides)
kerberos_auth = Kerberos()
endpoint = Endpoint("http://url.com", kerberos_auth)
client = ReliableHttpClient(endpoint, {}, retry_policy)
assert_is_not_none(client._auth)
assert isinstance(client._auth, HTTPKerberosAuth)
- assert hasattr(client._auth, 'mutual_authentication')
+ assert hasattr(client._auth, "mutual_authentication")
assert_equals(client._auth.mutual_authentication, OPTIONAL)
- assert hasattr(client._auth, 'force_preemptive')
+ assert hasattr(client._auth, "force_preemptive")
assert_equals(client._auth.force_preemptive, True)
diff --git a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py
index 396232990..cc62c27ca 100644
--- a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py
+++ b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py
@@ -3,7 +3,12 @@
import sparkmagic.utils.configuration as conf
from sparkmagic.utils.utils import parse_argstring_or_throw, initialize_auth
-from sparkmagic.utils.constants import EXPECTED_ERROR_MSG, MIMETYPE_TEXT_PLAIN, NO_AUTH, AUTH_BASIC
+from sparkmagic.utils.constants import (
+ EXPECTED_ERROR_MSG,
+ MIMETYPE_TEXT_PLAIN,
+ NO_AUTH,
+ AUTH_BASIC,
+)
from sparkmagic.magics.remotesparkmagics import RemoteSparkMagics
from sparkmagic.livyclientlib.command import Command
from sparkmagic.livyclientlib.endpoint import Endpoint
@@ -51,20 +56,21 @@ def test_info_endpoint_command_parses():
magic.spark(command)
- print_info_mock.assert_called_once_with(None,1234)
+ print_info_mock.assert_called_once_with(None, 1234)
@with_setup(_setup, _teardown)
def test_info_command_exception():
- print_info_mock = MagicMock(side_effect=LivyClientTimeoutException('OHHHHHOHOHOHO'))
+ print_info_mock = MagicMock(side_effect=LivyClientTimeoutException("OHHHHHOHOHOHO"))
magic._print_local_info = print_info_mock
command = "info"
magic.spark(command)
print_info_mock.assert_called_once_with()
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(print_info_mock.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(print_info_mock.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -80,8 +86,12 @@ def test_add_sessions_command_parses():
magic.spark(line)
args = parse_argstring_or_throw(RemoteSparkMagics.spark, line)
- add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)),
- False, {"kind": "pyspark"})
+ add_sessions_mock.assert_called_once_with(
+ "name",
+ Endpoint("http://url.com", initialize_auth(args)),
+ False,
+ {"kind": "pyspark"},
+ )
# Skip and scala - upper case
add_sessions_mock = MagicMock()
spark_controller.add_session = add_sessions_mock
@@ -94,8 +104,12 @@ def test_add_sessions_command_parses():
magic.spark(line)
args = parse_argstring_or_throw(RemoteSparkMagics.spark, line)
args.auth = NO_AUTH
- add_sessions_mock.assert_called_once_with("name", Endpoint("http://location:port", initialize_auth(args)),
- True, {"kind": "spark"})
+ add_sessions_mock.assert_called_once_with(
+ "name",
+ Endpoint("http://location:port", initialize_auth(args)),
+ True,
+ {"kind": "spark"},
+ )
@with_setup(_setup, _teardown)
@@ -106,21 +120,25 @@ def test_add_sessions_command_parses_kerberos():
command = "add"
name = "-s name"
language = "-l python"
- connection_string = "-u http://url.com -t {}".format('Kerberos')
+ connection_string = "-u http://url.com -t {}".format("Kerberos")
line = " ".join([command, name, language, connection_string])
magic.spark(line)
args = parse_argstring_or_throw(RemoteSparkMagics.spark, line)
auth_instance = initialize_auth(args)
-
- add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)),
- False, {"kind": "pyspark"})
+
+ add_sessions_mock.assert_called_once_with(
+ "name",
+ Endpoint("http://url.com", initialize_auth(args)),
+ False,
+ {"kind": "pyspark"},
+ )
assert_equals(auth_instance.url, "http://url.com")
@with_setup(_setup, _teardown)
def test_add_sessions_command_exception():
# Do not skip and python
- add_sessions_mock = MagicMock(side_effect=BadUserDataException('hehe'))
+ add_sessions_mock = MagicMock(side_effect=BadUserDataException("hehe"))
spark_controller.add_session = add_sessions_mock
command = "add"
name = "-s name"
@@ -130,16 +148,21 @@ def test_add_sessions_command_exception():
magic.spark(line)
args = parse_argstring_or_throw(RemoteSparkMagics.spark, line)
- add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)),
- False, {"kind": "pyspark"})
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(add_sessions_mock.side_effect))
+ add_sessions_mock.assert_called_once_with(
+ "name",
+ Endpoint("http://url.com", initialize_auth(args)),
+ False,
+ {"kind": "pyspark"},
+ )
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(add_sessions_mock.side_effect)
+ )
@with_setup(_setup, _teardown)
def test_add_sessions_command_extra_properties():
conf.override_all({})
- magic.spark("config", "{\"extra\": \"yes\"}")
+ magic.spark("config", '{"extra": "yes"}')
assert conf.session_configs() == {"extra": "yes"}
add_sessions_mock = MagicMock()
@@ -153,8 +176,12 @@ def test_add_sessions_command_extra_properties():
magic.spark(line)
args = parse_argstring_or_throw(RemoteSparkMagics.spark, line)
args.auth = NO_AUTH
- add_sessions_mock.assert_called_once_with("name", Endpoint("http://livyendpoint.com", initialize_auth(args)),
- False, {"kind": "spark", "extra": "yes"})
+ add_sessions_mock.assert_called_once_with(
+ "name",
+ Endpoint("http://livyendpoint.com", initialize_auth(args)),
+ False,
+ {"kind": "spark", "extra": "yes"},
+ )
conf.override_all({})
@@ -176,13 +203,14 @@ def test_delete_sessions_command_parses():
@with_setup(_setup, _teardown)
def test_delete_sessions_command_exception():
- mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('FEEEEEELINGS'))
+ mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("FEEEEEELINGS"))
spark_controller.delete_session_by_name = mock_method
command = "delete -s name"
magic.spark(command)
mock_method.assert_called_once_with("name")
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(mock_method.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(mock_method.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -198,14 +226,17 @@ def test_cleanup_command_parses():
@with_setup(_setup, _teardown)
def test_cleanup_command_exception():
- mock_method = MagicMock(side_effect=SessionManagementException('Livy did something VERY BAD'))
+ mock_method = MagicMock(
+ side_effect=SessionManagementException("Livy did something VERY BAD")
+ )
spark_controller.cleanup = mock_method
line = "cleanup"
magic.spark(line)
mock_method.assert_called_once_with()
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(mock_method.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(mock_method.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -232,7 +263,9 @@ def test_bad_command_writes_error():
magic.spark(line)
- ipython_display.send_error.assert_called_once_with("Subcommand '{}' not found. {}".format(line, usage))
+ ipython_display.send_error.assert_called_once_with(
+ "Subcommand '{}' not found. {}".format(line, usage)
+ )
@with_setup(_setup, _teardown)
@@ -270,13 +303,15 @@ def test_run_cell_command_writes_to_err():
run_cell_method.assert_called_once_with(Command(cell), name)
assert result is None
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG.format(result_value))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(result_value)
+ )
@with_setup(_setup, _teardown)
def test_run_cell_command_exception():
run_cell_method = MagicMock()
- run_cell_method.side_effect = HttpClientException('meh')
+ run_cell_method.side_effect = HttpClientException("meh")
spark_controller.run_command = run_cell_method
command = "-s"
@@ -288,8 +323,9 @@ def test_run_cell_command_exception():
run_cell_method.assert_called_once_with(Command(cell), name)
assert result is None
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(run_cell_method.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(run_cell_method.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -307,8 +343,9 @@ def test_run_spark_command_parses():
result = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with("cell code",
- None, "sample", None, None, "sessions_name", None)
+ magic.execute_spark.assert_called_once_with(
+ "cell code", None, "sample", None, None, "sessions_name", None
+ )
@with_setup(_setup, _teardown)
@@ -323,13 +360,16 @@ def test_run_spark_command_parses_with_coerce():
method_name = "sample"
coer = "--coerce"
coerce_value = "True"
- line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, coer, coerce_value]
+ )
cell = "cell code"
result = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with("cell code",
- None, "sample", None, None, "sessions_name", True)
+ magic.execute_spark.assert_called_once_with(
+ "cell code", None, "sample", None, None, "sessions_name", True
+ )
@with_setup(_setup, _teardown)
@@ -344,13 +384,16 @@ def test_run_spark_command_parses_with_coerce_false():
method_name = "sample"
coer = "--coerce"
coerce_value = "False"
- line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, coer, coerce_value]
+ )
cell = "cell code"
result = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with("cell code",
- None, "sample", None, None, "sessions_name", False)
+ magic.execute_spark.assert_called_once_with(
+ "cell code", None, "sample", None, None, "sessions_name", False
+ )
@with_setup(_setup, _teardown)
@@ -365,13 +408,16 @@ def test_run_sql_command_parses_with_coerce_false():
method_name = "sample"
coer = "--coerce"
coerce_value = "False"
- line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, coer, coerce_value]
+ )
cell = "cell code"
result = magic.spark(line, cell)
- magic.execute_sqlquery.assert_called_once_with("cell code",
- "sample", None, None, "sessions_name", None, False, False)
+ magic.execute_sqlquery.assert_called_once_with(
+ "cell code", "sample", None, None, "sessions_name", None, False, False
+ )
@with_setup(_setup, _teardown)
@@ -386,12 +432,16 @@ def test_run_spark_with_store_command_parses():
method_name = "sample"
output = "-o"
output_var = "var_name"
- line = " ".join([command, name, context, context_name, meth, method_name, output, output_var])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, output, output_var]
+ )
cell = "cell code"
result = magic.spark(line, cell)
- magic.execute_spark.assert_called_once_with("cell code",
- "var_name", "sample", None, None, "sessions_name", None)
+ magic.execute_spark.assert_called_once_with(
+ "cell code", "var_name", "sample", None, None, "sessions_name", None
+ )
+
@with_setup(_setup, _teardown)
def test_run_spark_with_store_correct_calls():
@@ -409,19 +459,34 @@ def test_run_spark_with_store_correct_calls():
output_var = "var_name"
coer = "--coerce"
coerce_value = "True"
- line = " ".join([command, name, context, context_name, meth, method_name, output, output_var, coer, coerce_value])
+ line = " ".join(
+ [
+ command,
+ name,
+ context,
+ context_name,
+ meth,
+ method_name,
+ output,
+ output_var,
+ coer,
+ coerce_value,
+ ]
+ )
cell = "cell code"
result = magic.spark(line, cell)
run_cell_method.assert_any_call(Command(cell), name)
- run_cell_method.assert_any_call(SparkStoreCommand(output_var, samplemethod=method_name, coerce=True), name)
+ run_cell_method.assert_any_call(
+ SparkStoreCommand(output_var, samplemethod=method_name, coerce=True), name
+ )
@with_setup(_setup, _teardown)
def test_run_spark_command_exception():
run_cell_method = MagicMock()
- run_cell_method.side_effect = LivyUnexpectedStatusException('WOW')
+ run_cell_method.side_effect = LivyUnexpectedStatusException("WOW")
spark_controller.run_command = run_cell_method
command = "-s"
@@ -432,19 +497,23 @@ def test_run_spark_command_exception():
method_name = "sample"
output = "-o"
output_var = "var_name"
- line = " ".join([command, name, context, context_name, meth, method_name, output, output_var])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, output, output_var]
+ )
cell = "cell code"
result = magic.spark(line, cell)
run_cell_method.assert_any_call(Command(cell), name)
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(run_cell_method.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(run_cell_method.side_effect)
+ )
+
@with_setup(_setup, _teardown)
def test_run_spark_command_exception_while_storing():
run_cell_method = MagicMock()
- exception = LivyUnexpectedStatusException('WOW')
+ exception = LivyUnexpectedStatusException("WOW")
run_cell_method.side_effect = [(True, "", MIMETYPE_TEXT_PLAIN), exception]
spark_controller.run_command = run_cell_method
@@ -456,17 +525,21 @@ def test_run_spark_command_exception_while_storing():
method_name = "sample"
output = "-o"
output_var = "var_name"
- line = " ".join([command, name, context, context_name, meth, method_name, output, output_var])
+ line = " ".join(
+ [command, name, context, context_name, meth, method_name, output, output_var]
+ )
cell = "cell code"
result = magic.spark(line, cell)
run_cell_method.assert_any_call(Command(cell), name)
- run_cell_method.assert_any_call(SparkStoreCommand(output_var, samplemethod=method_name), name)
+ run_cell_method.assert_any_call(
+ SparkStoreCommand(output_var, samplemethod=method_name), name
+ )
ipython_display.write.assert_called_once_with("")
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(exception))
-
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(exception)
+ )
@with_setup(_setup, _teardown)
@@ -486,14 +559,16 @@ def test_run_sql_command_parses():
result = magic.spark(line, cell)
- run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name)
+ run_cell_method.assert_called_once_with(
+ SQLQuery(cell, samplemethod=method_name), name
+ )
assert result is not None
@with_setup(_setup, _teardown)
def test_run_sql_command_exception():
run_cell_method = MagicMock()
- run_cell_method.side_effect = LivyUnexpectedStatusException('WOW')
+ run_cell_method.side_effect = LivyUnexpectedStatusException("WOW")
spark_controller.run_sqlquery = run_cell_method
command = "-s"
@@ -507,9 +582,12 @@ def test_run_sql_command_exception():
result = magic.spark(line, cell)
- run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name)
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(run_cell_method.side_effect))
+ run_cell_method.assert_called_once_with(
+ SQLQuery(cell, samplemethod=method_name), name
+ )
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(run_cell_method.side_effect)
+ )
@with_setup(_setup, _teardown)
@@ -530,7 +608,9 @@ def test_run_sql_command_knows_how_to_be_quiet():
result = magic.spark(line, cell)
- run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name)
+ run_cell_method.assert_called_once_with(
+ SQLQuery(cell, samplemethod=method_name), name
+ )
assert result is None
@@ -555,7 +635,9 @@ def test_logs_subcommand():
@with_setup(_setup, _teardown)
def test_logs_exception():
- get_logs_method = MagicMock(side_effect=LivyUnexpectedStatusException('How did this happen?'))
+ get_logs_method = MagicMock(
+ side_effect=LivyUnexpectedStatusException("How did this happen?")
+ )
result_value = ""
get_logs_method.return_value = result_value
spark_controller.get_logs = get_logs_method
@@ -569,5 +651,6 @@ def test_logs_exception():
get_logs_method.assert_called_once_with(name)
assert result is None
- ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG
- .format(get_logs_method.side_effect))
+ ipython_display.send_error.assert_called_once_with(
+ EXPECTED_ERROR_MSG.format(get_logs_method.side_effect)
+ )
diff --git a/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py b/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py
index f303e369b..afae91935 100644
--- a/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py
+++ b/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py
@@ -6,122 +6,211 @@
from nose.tools import assert_raises, assert_equals
from sparkmagic.livyclientlib.command import Command
import sparkmagic.utils.constants as constants
-from sparkmagic.livyclientlib.sendpandasdftosparkcommand import SendPandasDfToSparkCommand
+from sparkmagic.livyclientlib.sendpandasdftosparkcommand import (
+ SendPandasDfToSparkCommand,
+)
+
def test_send_to_scala():
- input_variable_name = 'input'
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
- output_variable_name = 'output'
+ input_variable_name = "input"
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
+ output_variable_name = "output"
maxrows = 1
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows)
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, maxrows
+ )
sparkcommand._scala_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_SPARK, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._scala_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_SPARK,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._scala_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_send_to_r():
- input_variable_name = 'input'
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
- output_variable_name = 'output'
+ input_variable_name = "input"
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
+ output_variable_name = "output"
maxrows = 1
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows)
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, maxrows
+ )
sparkcommand._r_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_SPARKR, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._r_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_SPARKR,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._r_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_send_to_python():
- input_variable_name = 'input'
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
- output_variable_name = 'output'
+ input_variable_name = "input"
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
+ output_variable_name = "output"
maxrows = 1
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows)
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, maxrows
+ )
sparkcommand._pyspark_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_PYSPARK, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._pyspark_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_PYSPARK,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._pyspark_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_should_create_a_valid_scala_expression():
input_variable_name = "input"
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
output_variable_name = "output"
- pandas_df_jsonized = u'''[{"A":1,"B":2}]'''
- expected_scala_code = u'''
+ pandas_df_jsonized = """[{"A":1,"B":2}]"""
+ expected_scala_code = '''
val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil)
- val {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name)
+ val {} = spark.read.json(rdd_json_array)'''.format(
+ pandas_df_jsonized, output_variable_name
+ )
+
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, 1
+ )
+ assert_equals(
+ sparkcommand._scala_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(expected_scala_code),
+ )
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1)
- assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name),
- Command(expected_scala_code))
def test_should_create_a_valid_r_expression():
input_variable_name = "input"
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
output_variable_name = "output"
- pandas_df_jsonized = u'''[{"A":1,"B":2}]'''
- expected_r_code = u'''
+ pandas_df_jsonized = """[{"A":1,"B":2}]"""
+ expected_r_code = """
fileConn<-file("temporary_pandas_df_sparkmagics.txt")
writeLines('{}', fileConn)
close(fileConn)
{} <- read.json("temporary_pandas_df_sparkmagics.txt")
{}.persist()
- file.remove("temporary_pandas_df_sparkmagics.txt")'''.format(pandas_df_jsonized, output_variable_name, output_variable_name)
+ file.remove("temporary_pandas_df_sparkmagics.txt")""".format(
+ pandas_df_jsonized, output_variable_name, output_variable_name
+ )
+
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, 1
+ )
+ assert_equals(
+ sparkcommand._r_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(expected_r_code),
+ )
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1)
- assert_equals(sparkcommand._r_command(input_variable_name, input_variable_value, output_variable_name),
- Command(expected_r_code))
def test_should_create_a_valid_python3_expression():
input_variable_name = "input"
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
output_variable_name = "output"
- pandas_df_jsonized = u'''[{"A":1,"B":2}]'''
+ pandas_df_jsonized = """[{"A":1,"B":2}]"""
expected_python3_code = SendPandasDfToSparkCommand._python_decode
- expected_python3_code += u'''
+ expected_python3_code += """
json_array = json_loads_byteified('{}')
rdd_json_array = spark.sparkContext.parallelize(json_array)
- {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name)
+ {} = spark.read.json(rdd_json_array)""".format(
+ pandas_df_jsonized, output_variable_name
+ )
+
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, 1
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(expected_python3_code),
+ )
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1)
- assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name),
- Command(expected_python3_code))
def test_should_create_a_valid_python2_expression():
input_variable_name = "input"
- input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]})
+ input_variable_value = pd.DataFrame({"A": [1], "B": [2]})
output_variable_name = "output"
- pandas_df_jsonized = u'''[{"A":1,"B":2}]'''
+ pandas_df_jsonized = """[{"A":1,"B":2}]"""
expected_python2_code = SendPandasDfToSparkCommand._python_decode
- expected_python2_code += u'''
+ expected_python2_code += """
json_array = json_loads_byteified('{}')
rdd_json_array = spark.sparkContext.parallelize(json_array)
- {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name)
+ {} = spark.read.json(rdd_json_array)""".format(
+ pandas_df_jsonized, output_variable_name
+ )
+
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, 1
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(expected_python2_code),
+ )
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1)
- assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name),
- Command(expected_python2_code))
def test_should_properly_limit_pandas_dataframe():
input_variable_name = "input"
max_rows = 1
- input_variable_value = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B' : [5, 6, 7, 8, 9]})
+ input_variable_value = pd.DataFrame({"A": [0, 1, 2, 3, 4], "B": [5, 6, 7, 8, 9]})
output_variable_name = "output"
- pandas_df_jsonized = u'''[{"A":0,"B":5}]''' #notice we expect json to have dropped all but one row
- expected_scala_code = u'''
+ pandas_df_jsonized = (
+ """[{"A":0,"B":5}]""" # notice we expect json to have dropped all but one row
+ )
+ expected_scala_code = '''
val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil)
- val {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name)
+ val {} = spark.read.json(rdd_json_array)'''.format(
+ pandas_df_jsonized, output_variable_name
+ )
+
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, max_rows
+ )
+ assert_equals(
+ sparkcommand._scala_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(expected_scala_code),
+ )
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, max_rows)
- assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name),
- Command(expected_scala_code))
def test_should_raise_when_input_is_not_pandas_df():
input_variable_name = "input"
input_variable_value = "not a pandas dataframe"
output_variable_name = "output"
- sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1)
- assert_raises(BadUserDataException, sparkcommand.to_command, "spark", input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendPandasDfToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name, 1
+ )
+ assert_raises(
+ BadUserDataException,
+ sparkcommand.to_command,
+ "spark",
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
diff --git a/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py b/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py
index 3d36ceb0c..a0e5c9ef0 100644
--- a/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py
+++ b/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py
@@ -6,67 +6,140 @@
from sparkmagic.livyclientlib.command import Command
import sparkmagic.utils.constants as constants
+
def test_send_to_scala():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
sparkcommand._scala_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_SPARK, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._scala_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_SPARK,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._scala_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_send_to_r():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
sparkcommand._r_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_SPARKR, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._r_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_SPARKR,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._r_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_send_to_pyspark():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
sparkcommand._pyspark_command = MagicMock(return_value=MagicMock())
- sparkcommand.to_command(constants.SESSION_KIND_PYSPARK, input_variable_name, input_variable_value, output_variable_name)
- sparkcommand._pyspark_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand.to_command(
+ constants.SESSION_KIND_PYSPARK,
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+ sparkcommand._pyspark_command.assert_called_with(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+
def test_to_command_invalid():
input_variable_name = "input"
input_variable_value = 42
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
- assert_raises(BadUserDataException, sparkcommand.to_command, "invalid", input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+ assert_raises(
+ BadUserDataException,
+ sparkcommand.to_command,
+ "invalid",
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+
def test_should_raise_when_input_aint_a_string():
input_variable_name = "input"
input_variable_value = 42
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
- assert_raises(BadUserDataException, sparkcommand.to_command, "spark", input_variable_name, input_variable_value, output_variable_name)
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+ assert_raises(
+ BadUserDataException,
+ sparkcommand.to_command,
+ "spark",
+ input_variable_name,
+ input_variable_value,
+ output_variable_name,
+ )
+
def test_should_create_a_valid_scala_expression():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
- assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name),
- Command(u'var {} = """{}"""'.format(output_variable_name, input_variable_value)))
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+ assert_equals(
+ sparkcommand._scala_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command('var {} = """{}"""'.format(output_variable_name, input_variable_value)),
+ )
+
def test_should_create_a_valid_python_expression():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
- assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name),
- Command(u'{} = {}'.format(output_variable_name, repr(input_variable_value))))
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command("{} = {}".format(output_variable_name, repr(input_variable_value))),
+ )
+
def test_should_create_a_valid_r_expression():
input_variable_name = "input"
input_variable_value = "value"
output_variable_name = "output"
- sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name)
- assert_equals(sparkcommand._r_command(input_variable_name, input_variable_value, output_variable_name),
- Command(u'''assign("{}","{}")'''.format(output_variable_name, input_variable_value)))
+ sparkcommand = SendStringToSparkCommand(
+ input_variable_name, input_variable_value, output_variable_name
+ )
+ assert_equals(
+ sparkcommand._r_command(
+ input_variable_name, input_variable_value, output_variable_name
+ ),
+ Command(
+ """assign("{}","{}")""".format(output_variable_name, input_variable_value)
+ ),
+ )
diff --git a/sparkmagic/sparkmagic/tests/test_sessionmanager.py b/sparkmagic/sparkmagic/tests/test_sessionmanager.py
index 1a010ae41..eb666c2d6 100644
--- a/sparkmagic/sparkmagic/tests/test_sessionmanager.py
+++ b/sparkmagic/sparkmagic/tests/test_sessionmanager.py
@@ -111,7 +111,9 @@ def test_cleanup_all_sessions_on_exit():
client0.delete.assert_called_once_with()
client1.delete.assert_called_once_with()
- manager.ipython_display.writeln.assert_called_once_with(u"Cleaning up livy sessions on exit is enabled")
+ manager.ipython_display.writeln.assert_called_once_with(
+ "Cleaning up livy sessions on exit is enabled"
+ )
def test_cleanup_all_sessions_on_exit_fails():
@@ -121,7 +123,7 @@ def test_cleanup_all_sessions_on_exit_fails():
conf.override(conf.cleanup_all_sessions_on_exit.__name__, True)
client0 = MagicMock()
client1 = MagicMock()
- client0.delete.side_effect = Exception('Mocked exception for client1.delete')
+ client0.delete.side_effect = Exception("Mocked exception for client1.delete")
manager = get_session_manager()
manager.add_session("name0", client0)
manager.add_session("name1", client1)
@@ -150,7 +152,7 @@ def test_get_session_name_by_id_endpoint():
name = manager.get_session_name_by_id_endpoint(id_to_search, endpoint_to_search)
assert_equals(None, name)
-
+
session = MagicMock()
type(session).id = PropertyMock(return_value=int(id_to_search))
session.endpoint = endpoint_to_search
diff --git a/sparkmagic/sparkmagic/tests/test_sparkcontroller.py b/sparkmagic/sparkmagic/tests/test_sparkcontroller.py
index fa3da4263..f877db103 100644
--- a/sparkmagic/sparkmagic/tests/test_sparkcontroller.py
+++ b/sparkmagic/sparkmagic/tests/test_sparkcontroller.py
@@ -3,7 +3,10 @@
from mock import MagicMock
from nose.tools import with_setup, assert_equals, raises
from sparkmagic.livyclientlib.endpoint import Endpoint
-from sparkmagic.livyclientlib.exceptions import SessionManagementException, HttpClientException
+from sparkmagic.livyclientlib.exceptions import (
+ SessionManagementException,
+ HttpClientException,
+)
from sparkmagic.livyclientlib.sparkcontroller import SparkController
client_manager = None
@@ -34,9 +37,11 @@ def _setup():
controller.session_manager = client_manager
controller.spark_events = spark_events
+
def _teardown():
pass
+
@with_setup(_setup, _teardown)
def test_add_session():
name = "name"
@@ -49,7 +54,9 @@ def test_add_session():
controller.add_session(name, endpoint, False, properties)
- controller._livy_session.assert_called_once_with(controller._http_client.return_value, properties, ipython_display)
+ controller._livy_session.assert_called_once_with(
+ controller._http_client.return_value, properties, ipython_display
+ )
controller.session_manager.add_session.assert_called_once_with(name, session)
session.start.assert_called_once()
@@ -129,10 +136,12 @@ def test_get_client_keys():
@with_setup(_setup, _teardown)
def test_get_all_sessions():
http_client = MagicMock()
- http_client.get_sessions.return_value = json.loads('{"from":0,"total":3,"sessions":[{"id":0,"state":"idle","kind":'
- '"spark","log":[""]}, {"id":1,"state":"busy","kind":"spark","log"'
- ':[""]},{"id":2,"state":"busy","kind":"sql","log"'
- ':[""]}]}')
+ http_client.get_sessions.return_value = json.loads(
+ '{"from":0,"total":3,"sessions":[{"id":0,"state":"idle","kind":'
+ '"spark","log":[""]}, {"id":1,"state":"busy","kind":"spark","log"'
+ ':[""]},{"id":2,"state":"busy","kind":"sql","log"'
+ ':[""]}]}'
+ )
controller._http_client = MagicMock(return_value=http_client)
controller._livy_session = MagicMock()
@@ -156,21 +165,27 @@ def test_cleanup_endpoint():
@with_setup(_setup, _teardown)
def test_delete_session_by_id_existent_non_managed():
http_client = MagicMock()
- http_client.get_session.return_value = json.loads('{"id":0,"state":"starting","kind":"spark","log":[]}')
+ http_client.get_session.return_value = json.loads(
+ '{"id":0,"state":"starting","kind":"spark","log":[]}'
+ )
controller._http_client = MagicMock(return_value=http_client)
session = MagicMock()
controller._livy_session = MagicMock(return_value=session)
controller.delete_session_by_id("conn_str", 0)
- controller._livy_session.assert_called_once_with(http_client, {"kind": "spark"}, ipython_display, 0)
+ controller._livy_session.assert_called_once_with(
+ http_client, {"kind": "spark"}, ipython_display, 0
+ )
session.delete.assert_called_once_with()
@with_setup(_setup, _teardown)
def test_delete_session_by_id_existent_managed():
name = "name"
- controller.session_manager.get_session_name_by_id_endpoint = MagicMock(return_value=name)
+ controller.session_manager.get_session_name_by_id_endpoint = MagicMock(
+ return_value=name
+ )
controller.session_manager.get_sessions_list = MagicMock(return_value=[name])
controller.delete_session_by_name = MagicMock()
@@ -190,6 +205,7 @@ def test_delete_session_by_id_non_existent():
controller.delete_session_by_id("conn_str", 0)
+
@with_setup(_setup, _teardown)
def test_get_app_id():
chosen_client = MagicMock()
@@ -199,6 +215,7 @@ def test_get_app_id():
assert_equals(result, chosen_client.get_app_id.return_value)
chosen_client.get_app_id.assert_called_with()
+
@with_setup(_setup, _teardown)
def test_get_driver_log():
chosen_client = MagicMock()
@@ -208,6 +225,7 @@ def test_get_driver_log():
assert_equals(result, chosen_client.get_driver_log_url.return_value)
chosen_client.get_driver_log_url.assert_called_with()
+
@with_setup(_setup, _teardown)
def test_get_logs():
chosen_client = MagicMock()
@@ -217,19 +235,24 @@ def test_get_logs():
assert_equals(result, chosen_client.get_logs.return_value)
chosen_client.get_logs.assert_called_with()
+
@with_setup(_setup, _teardown)
@raises(SessionManagementException)
def test_get_logs_error():
chosen_client = MagicMock()
- controller.get_session_by_name_or_default = MagicMock(side_effect=SessionManagementException('THERE WAS A SPOOKY GHOST'))
+ controller.get_session_by_name_or_default = MagicMock(
+ side_effect=SessionManagementException("THERE WAS A SPOOKY GHOST")
+ )
result = controller.get_logs()
+
@with_setup(_setup, _teardown)
def test_get_session_id_for_client():
assert controller.get_session_id_for_client("name") is not None
client_manager.get_session_id_for_client.assert_called_once_with("name")
+
@with_setup(_setup, _teardown)
def test_get_spark_ui_url():
chosen_client = MagicMock()
@@ -239,6 +262,7 @@ def test_get_spark_ui_url():
assert_equals(result, chosen_client.get_spark_ui_url.return_value)
chosen_client.get_spark_ui_url.assert_called_with()
+
@with_setup(_setup, _teardown)
def test_add_session_throws_when_session_start_fails():
name = "name"
@@ -259,18 +283,22 @@ def test_add_session_throws_when_session_start_fails():
session.start.assert_called_once()
controller.session_manager.add_session.assert_not_called
+
@with_setup(_setup, _teardown)
def test_add_session_cleanup_when_timeouts_and_session_posted_to_livy():
pass
+
@with_setup(_setup, _teardown)
def test_add_session_cleanup_when_timeouts_and_session_posted_to_livy():
_do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy=True)
+
@with_setup(_setup, _teardown)
def test_add_session_cleanup_when_timeouts_and_session_not_posted_to_livy():
_do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy=False)
+
def _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy):
name = "name"
properties = {"kind": "spark"}
@@ -300,6 +328,7 @@ def _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy):
else:
session.delete.assert_not_called()
+
@with_setup(_setup, _teardown)
def test_add_session_cleanup_when_session_delete_throws():
name = "name"
diff --git a/sparkmagic/sparkmagic/tests/test_sparkevents.py b/sparkmagic/sparkmagic/tests/test_sparkevents.py
index 340c378b4..116d2b7de 100644
--- a/sparkmagic/sparkmagic/tests/test_sparkevents.py
+++ b/sparkmagic/sparkmagic/tests/test_sparkevents.py
@@ -27,9 +27,11 @@ def _teardown():
@with_setup(_setup, _teardown)
def test_emit_library_loaded_event():
event_name = constants.LIBRARY_LOADED_EVENT
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ ]
spark_events.emit_library_loaded_event()
@@ -42,13 +44,15 @@ def test_emit_cluster_change_event():
status_code = 200
event_name = constants.CLUSTER_CHANGE_EVENT
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.CLUSTER_DNS_NAME, guid1),
- (constants.STATUS_CODE, status_code),
- (constants.SUCCESS, True),
- (constants.ERROR_MESSAGE, None)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.CLUSTER_DNS_NAME, guid1),
+ (constants.STATUS_CODE, status_code),
+ (constants.SUCCESS, True),
+ (constants.ERROR_MESSAGE, None),
+ ]
spark_events.emit_cluster_change_event(guid1, status_code, True, None)
@@ -60,11 +64,13 @@ def test_emit_cluster_change_event():
def test_emit_session_creation_start_event():
language = constants.SESSION_KIND_SPARK
event_name = constants.SESSION_CREATION_START_EVENT
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ ]
spark_events.emit_session_creation_start_event(guid1, language)
@@ -79,18 +85,22 @@ def test_emit_session_creation_end_event():
event_name = constants.SESSION_CREATION_END_EVENT
status = constants.BUSY_SESSION_STATUS
session_id = 0
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.STATUS, status),
- (constants.SUCCESS, True),
- (constants.EXCEPTION_TYPE, ""),
- (constants.EXCEPTION_MESSAGE, "")]
-
- spark_events.emit_session_creation_end_event(guid1, language, session_id, status, True, "", "")
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.STATUS, status),
+ (constants.SUCCESS, True),
+ (constants.EXCEPTION_TYPE, ""),
+ (constants.EXCEPTION_MESSAGE, ""),
+ ]
+
+ spark_events.emit_session_creation_end_event(
+ guid1, language, session_id, status, True, "", ""
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -103,13 +113,15 @@ def test_emit_session_deletion_start_event():
event_name = constants.SESSION_DELETION_START_EVENT
status = constants.BUSY_SESSION_STATUS
session_id = 0
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.STATUS, status)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.STATUS, status),
+ ]
spark_events.emit_session_deletion_start_event(guid1, language, session_id, status)
@@ -124,18 +136,22 @@ def test_emit_session_deletion_end_event():
event_name = constants.SESSION_DELETION_END_EVENT
status = constants.BUSY_SESSION_STATUS
session_id = 0
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.STATUS, status),
- (constants.SUCCESS, True),
- (constants.EXCEPTION_TYPE, ""),
- (constants.EXCEPTION_MESSAGE, "")]
-
- spark_events.emit_session_deletion_end_event(guid1, language, session_id, status, True, "", "")
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.STATUS, status),
+ (constants.SUCCESS, True),
+ (constants.EXCEPTION_TYPE, ""),
+ (constants.EXCEPTION_MESSAGE, ""),
+ ]
+
+ spark_events.emit_session_deletion_end_event(
+ guid1, language, session_id, status, True, "", ""
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -148,15 +164,19 @@ def test_emit_statement_execution_start_event():
session_id = 7
event_name = constants.STATEMENT_EXECUTION_START_EVENT
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.STATEMENT_GUID, guid2)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.STATEMENT_GUID, guid2),
+ ]
- spark_events.emit_statement_execution_start_event(guid1, language, session_id, guid2)
+ spark_events.emit_statement_execution_start_event(
+ guid1, language, session_id, guid2
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -170,23 +190,33 @@ def test_emit_statement_execution_end_event():
statement_id = 400
event_name = constants.STATEMENT_EXECUTION_END_EVENT
success = True
- exception_type = ''
- exception_message = 'foo'
-
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.STATEMENT_GUID, guid2),
- (constants.STATEMENT_ID, statement_id),
- (constants.SUCCESS, success),
- (constants.EXCEPTION_TYPE, exception_type),
- (constants.EXCEPTION_MESSAGE, exception_message)]
-
- spark_events.emit_statement_execution_end_event(guid1, language, session_id, guid2, statement_id, success,
- exception_type, exception_message)
+ exception_type = ""
+ exception_message = "foo"
+
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.STATEMENT_GUID, guid2),
+ (constants.STATEMENT_ID, statement_id),
+ (constants.SUCCESS, success),
+ (constants.EXCEPTION_TYPE, exception_type),
+ (constants.EXCEPTION_MESSAGE, exception_message),
+ ]
+
+ spark_events.emit_statement_execution_end_event(
+ guid1,
+ language,
+ session_id,
+ guid2,
+ statement_id,
+ success,
+ exception_type,
+ exception_message,
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -198,23 +228,26 @@ def test_emit_sql_execution_start_event():
event_name = constants.SQL_EXECUTION_START_EVENT
session_id = 22
language = constants.SESSION_KIND_SPARK
- samplemethod = 'sample'
+ samplemethod = "sample"
maxrows = 12
samplefraction = 0.5
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.SQL_GUID, guid2),
- (constants.SAMPLE_METHOD, samplemethod),
- (constants.MAX_ROWS, maxrows),
- (constants.SAMPLE_FRACTION, samplefraction)]
-
- spark_events.emit_sql_execution_start_event(guid1, language, session_id, guid2, samplemethod,
- maxrows, samplefraction)
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.SQL_GUID, guid2),
+ (constants.SAMPLE_METHOD, samplemethod),
+ (constants.MAX_ROWS, maxrows),
+ (constants.SAMPLE_FRACTION, samplefraction),
+ ]
+
+ spark_events.emit_sql_execution_start_event(
+ guid1, language, session_id, guid2, samplemethod, maxrows, samplefraction
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -227,23 +260,33 @@ def test_emit_sql_execution_end_event():
session_id = 17
language = constants.SESSION_KIND_SPARK
success = False
- exception_type = 'ValueError'
- exception_message = 'You screwed up'
-
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.SESSION_GUID, guid1),
- (constants.LIVY_KIND, language),
- (constants.SESSION_ID, session_id),
- (constants.SQL_GUID, guid2),
- (constants.STATEMENT_GUID, guid3),
- (constants.SUCCESS, success),
- (constants.EXCEPTION_TYPE, exception_type),
- (constants.EXCEPTION_MESSAGE, exception_message)]
-
- spark_events.emit_sql_execution_end_event(guid1, language, session_id, guid2, guid3,
- success, exception_type, exception_message)
+ exception_type = "ValueError"
+ exception_message = "You screwed up"
+
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.SESSION_GUID, guid1),
+ (constants.LIVY_KIND, language),
+ (constants.SESSION_ID, session_id),
+ (constants.SQL_GUID, guid2),
+ (constants.STATEMENT_GUID, guid3),
+ (constants.SUCCESS, success),
+ (constants.EXCEPTION_TYPE, exception_type),
+ (constants.EXCEPTION_MESSAGE, exception_message),
+ ]
+
+ spark_events.emit_sql_execution_end_event(
+ guid1,
+ language,
+ session_id,
+ guid2,
+ guid3,
+ success,
+ exception_type,
+ exception_message,
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -253,15 +296,17 @@ def test_emit_sql_execution_end_event():
@with_setup(_setup, _teardown)
def test_emit_magic_execution_start_event():
event_name = constants.MAGIC_EXECUTION_START_EVENT
- magic_name = 'sql'
+ magic_name = "sql"
language = constants.SESSION_KIND_SPARKR
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.MAGIC_NAME, magic_name),
- (constants.LIVY_KIND, language),
- (constants.MAGIC_GUID, guid1)]
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.MAGIC_NAME, magic_name),
+ (constants.LIVY_KIND, language),
+ (constants.MAGIC_GUID, guid1),
+ ]
spark_events.emit_magic_execution_start_event(magic_name, language, guid1)
@@ -273,24 +318,27 @@ def test_emit_magic_execution_start_event():
@with_setup(_setup, _teardown)
def test_emit_magic_execution_end_event():
event_name = constants.MAGIC_EXECUTION_END_EVENT
- magic_name = 'sql'
+ magic_name = "sql"
language = constants.SESSION_KIND_SPARKR
success = True
- exception_type = ''
- exception_message = ''
-
- kwargs_list = [(INSTANCE_ID, get_instance_id()),
- (EVENT_NAME, event_name),
- (TIMESTAMP, time_stamp),
- (constants.MAGIC_NAME, magic_name),
- (constants.LIVY_KIND, language),
- (constants.MAGIC_GUID, guid1),
- (constants.SUCCESS, success),
- (constants.EXCEPTION_TYPE, exception_type),
- (constants.EXCEPTION_MESSAGE, exception_message)]
-
- spark_events.emit_magic_execution_end_event(magic_name, language, guid1, success,
- exception_type, exception_message)
+ exception_type = ""
+ exception_message = ""
+
+ kwargs_list = [
+ (INSTANCE_ID, get_instance_id()),
+ (EVENT_NAME, event_name),
+ (TIMESTAMP, time_stamp),
+ (constants.MAGIC_NAME, magic_name),
+ (constants.LIVY_KIND, language),
+ (constants.MAGIC_GUID, guid1),
+ (constants.SUCCESS, success),
+ (constants.EXCEPTION_TYPE, exception_type),
+ (constants.EXCEPTION_MESSAGE, exception_message),
+ ]
+
+ spark_events.emit_magic_execution_end_event(
+ magic_name, language, guid1, success, exception_type, exception_message
+ )
spark_events._verify_language_ok.assert_called_once_with(language)
spark_events.get_utc_date_time.assert_called_with()
@@ -304,4 +352,4 @@ def test_magic_verify_language_ok():
@raises(AssertionError)
def test_magic_verify_language_ok_error():
- SparkEvents()._verify_language_ok('NYARGLEBARGLE')
+ SparkEvents()._verify_language_ok("NYARGLEBARGLE")
diff --git a/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py b/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py
index f901c8e48..6d383b642 100644
--- a/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py
+++ b/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py
@@ -16,8 +16,9 @@
class TestSparkKernel(SparkKernelBase):
def __init__(self):
kwargs = {"testing": True}
- super(TestSparkKernel, self).__init__(None, None, None, None, None, LANG_PYTHON, user_code_parser,
- **kwargs)
+ super(TestSparkKernel, self).__init__(
+ None, None, None, None, None, LANG_PYTHON, user_code_parser, **kwargs
+ )
def _setup():
@@ -25,8 +26,9 @@ def _setup():
kernel = TestSparkKernel()
- kernel._execute_cell_for_user = execute_cell_mock = MagicMock(return_value={'test': 'ing', 'a': 'b',
- 'status': 'ok'})
+ kernel._execute_cell_for_user = execute_cell_mock = MagicMock(
+ return_value={"test": "ing", "a": "b", "status": "ok"}
+ )
kernel._do_shutdown_ipykernel = do_shutdown_mock = MagicMock()
kernel.ipython_display = ipython_display = MagicMock()
@@ -78,14 +80,16 @@ def test_execute_alerts_user_if_an_unexpected_error_happens():
@with_setup(_setup, _teardown)
def test_execute_throws_if_fatal_error_happens_for_execution():
# Verify that the kernel sends the error from Python execution's context to the user
- fatal_error = u"Error."
- message = "{}\nException details:\n\t\"{}\"".format(fatal_error, fatal_error)
+ fatal_error = "Error."
+ message = '{}\nException details:\n\t"{}"'.format(fatal_error, fatal_error)
reply_content = dict()
- reply_content[u"status"] = u"error"
- reply_content[u"evalue"] = fatal_error
+ reply_content["status"] = "error"
+ reply_content["evalue"] = fatal_error
execute_cell_mock.return_value = reply_content
- ret = kernel._execute_cell(code, False, shutdown_if_error=True, log_if_error=fatal_error)
+ ret = kernel._execute_cell(
+ code, False, shutdown_if_error=True, log_if_error=fatal_error
+ )
assert ret is execute_cell_mock.return_value
assert kernel._fatal_error == message
@@ -118,55 +122,80 @@ def test_shutdown_cleans_up():
def test_register_auto_viz():
kernel._register_auto_viz()
- assert call("from autovizwidget.widget.utils import display_dataframe\nip = get_ipython()\nip.display_formatter"
- ".ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)",
- True, False, None, False) in execute_cell_mock.mock_calls
+ assert (
+ call(
+ "from autovizwidget.widget.utils import display_dataframe\nip = get_ipython()\nip.display_formatter"
+ ".ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)",
+ True,
+ False,
+ None,
+ False,
+ )
+ in execute_cell_mock.mock_calls
+ )
@with_setup(_setup, _teardown)
def test_change_language():
kernel._change_language()
- assert call("%%_do_not_call_change_language -l {}\n ".format(LANG_PYTHON),
- True, False, None, False) in execute_cell_mock.mock_calls
+ assert (
+ call(
+ "%%_do_not_call_change_language -l {}\n ".format(LANG_PYTHON),
+ True,
+ False,
+ None,
+ False,
+ )
+ in execute_cell_mock.mock_calls
+ )
@with_setup(_setup, _teardown)
def test_load_magics():
kernel._load_magics_extension()
- assert call("%load_ext sparkmagic.kernels", True, False, None, False) in execute_cell_mock.mock_calls
+ assert (
+ call("%load_ext sparkmagic.kernels", True, False, None, False)
+ in execute_cell_mock.mock_calls
+ )
@with_setup(_setup, _teardown)
def test_delete_session():
kernel._delete_session()
- assert call("%%_do_not_call_delete_session\n ", True, False) in execute_cell_mock.mock_calls
+ assert (
+ call("%%_do_not_call_delete_session\n ", True, False)
+ in execute_cell_mock.mock_calls
+ )
+
-@patch.object(ipykernel.ipkernel.IPythonKernel, 'do_execute')
+@patch.object(ipykernel.ipkernel.IPythonKernel, "do_execute")
@with_setup(_teardown)
def test_execute_cell_for_user_ipykernel5(mock_ipy_execute):
import sys
+
if sys.version_info.major == 2:
from unittest import SkipTest
+
raise SkipTest("Python 3 only")
else:
import asyncio
mock_ipy_execute_result = asyncio.Future()
- mock_ipy_execute_result.set_result({'status': 'OK'})
+ mock_ipy_execute_result.set_result({"status": "OK"})
mock_ipy_execute.return_value = mock_ipy_execute_result
- actual_result = TestSparkKernel()._execute_cell_for_user(code='Foo', silent=True)
+ actual_result = TestSparkKernel()._execute_cell_for_user(code="Foo", silent=True)
- assert {'status': 'OK'} == actual_result
+ assert {"status": "OK"} == actual_result
-@patch.object(ipykernel.ipkernel.IPythonKernel, 'do_execute')
+@patch.object(ipykernel.ipkernel.IPythonKernel, "do_execute")
@with_setup(_teardown)
def test_execute_cell_for_user_ipykernel4(mock_ipy_execute):
- mock_ipy_execute.return_value = {'status': 'OK'}
+ mock_ipy_execute.return_value = {"status": "OK"}
- actual_result = TestSparkKernel()._execute_cell_for_user(code='Foo', silent=True)
+ actual_result = TestSparkKernel()._execute_cell_for_user(code="Foo", silent=True)
- assert {'status': 'OK'} == actual_result
+ assert {"status": "OK"} == actual_result
diff --git a/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py b/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py
index 19b5999fd..5084a4f88 100644
--- a/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py
+++ b/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py
@@ -4,13 +4,25 @@
from nose.tools import with_setup, assert_equals, assert_raises, raises
from sparkmagic.utils.configuration import get_livy_kind
-from sparkmagic.utils.constants import LANGS_SUPPORTED, SESSION_KIND_PYSPARK, SESSION_KIND_SPARK, \
- IDLE_SESSION_STATUS, BUSY_SESSION_STATUS, MIMETYPE_TEXT_PLAIN, EXPECTED_ERROR_MSG
+from sparkmagic.utils.constants import (
+ LANGS_SUPPORTED,
+ SESSION_KIND_PYSPARK,
+ SESSION_KIND_SPARK,
+ IDLE_SESSION_STATUS,
+ BUSY_SESSION_STATUS,
+ MIMETYPE_TEXT_PLAIN,
+ EXPECTED_ERROR_MSG,
+)
from sparkmagic.magics.sparkmagicsbase import SparkMagicBase
-from sparkmagic.livyclientlib.exceptions import DataFrameParseException, BadUserDataException, SparkStatementException
+from sparkmagic.livyclientlib.exceptions import (
+ DataFrameParseException,
+ BadUserDataException,
+ SparkStatementException,
+)
from sparkmagic.livyclientlib.sqlquery import SQLQuery
from sparkmagic.livyclientlib.sparkstorecommand import SparkStoreCommand
+
def _setup():
global magic, session, shell, ipython_display
shell = MagicMock()
@@ -22,6 +34,7 @@ def _setup():
magic.ipython_display = MagicMock()
conf.override_all({})
+
def _teardown():
pass
@@ -119,22 +132,24 @@ def test_print_endpoint_info():
current_session_id = 1
session1 = MagicMock()
session1.id = 1
- session1.get_row_html.return_value = u"""<tr><td>row1</td></tr>"""
+ session1.get_row_html.return_value = """<tr><td>row1</td></tr>"""
session2 = MagicMock()
session2.id = 3
- session2.get_row_html.return_value = u"""<tr><td>row2</td></tr>"""
+ session2.get_row_html.return_value = """<tr><td>row2</td></tr>"""
magic._print_endpoint_info([session2, session1], current_session_id)
- magic.ipython_display.html.assert_called_once_with(u"""<table>
+ magic.ipython_display.html.assert_called_once_with(
+ """<table>
<tr><th>ID</th><th>YARN Application ID</th><th>Kind</th><th>State</th><th>Spark UI</th><th>Driver log</th><th>User</th><th>Current session?</th></tr>\
<tr><td>row1</td></tr><tr><td>row2</td></tr>\
-</table>""")
+</table>"""
+ )
@with_setup(_setup, _teardown)
def test_print_empty_endpoint_info():
current_session_id = None
magic._print_endpoint_info([], current_session_id)
- magic.ipython_display.html.assert_called_once_with(u'No active sessions.')
+ magic.ipython_display.html.assert_called_once_with("No active sessions.")
@with_setup(_setup, _teardown)
@@ -146,7 +161,10 @@ def test_send_to_spark_should_raise_when_variable_value_is_none():
max_rows = 25000
magic.shell.user_ns[input_variable_name] = None
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
+
@with_setup(_setup, _teardown)
@raises(BadUserDataException)
@@ -158,7 +176,10 @@ def test_send_to_spark_should_raise_when_type_is_incorrect():
max_rows = 25000
magic.shell.user_ns[input_variable_name] = input_variable_value
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
+
@with_setup(_setup, _teardown)
def test_send_to_spark_should_print_error_when_str_command_failed():
@@ -169,13 +190,20 @@ def test_send_to_spark_should_print_error_when_str_command_failed():
output_value = "error"
max_rows = 25000
magic.shell.user_ns[input_variable_name] = input_variable_value
- magic.spark_controller.run_command.return_value = (False, output_value, "text/plain")
+ magic.spark_controller.run_command.return_value = (
+ False,
+ output_value,
+ "text/plain",
+ )
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
magic.ipython_display.send_error.assert_called_once_with(output_value)
assert not magic.ipython_display.write.called
+
@with_setup(_setup, _teardown)
def test_send_to_spark_should_print_error_when_df_command_failed():
input_variable_name = "x_in"
@@ -185,13 +213,20 @@ def test_send_to_spark_should_print_error_when_df_command_failed():
output_value = "error"
max_rows = 25000
magic.shell.user_ns[input_variable_name] = input_variable_value
- magic.spark_controller.run_command.return_value = (False, output_value, "text/plain")
+ magic.spark_controller.run_command.return_value = (
+ False,
+ output_value,
+ "text/plain",
+ )
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
magic.ipython_display.send_error.assert_called_once_with(output_value)
assert not magic.ipython_display.write.called
+
@with_setup(_setup, _teardown)
def test_send_to_spark_should_name_the_output_variable_the_same_as_input_name_when_custom_name_not_provided():
input_variable_name = "x_in"
@@ -201,13 +236,18 @@ def test_send_to_spark_should_name_the_output_variable_the_same_as_input_name_wh
max_rows = 25000
magic.shell.user_ns[input_variable_name] = input_variable_value
magic.spark_controller.run_command.return_value = (True, output_value, "text/plain")
- expected_message = u'Successfully passed \'{}\' as \'{}\' to Spark kernel'.format(input_variable_name, input_variable_name)
+ expected_message = "Successfully passed '{}' as '{}' to Spark kernel".format(
+ input_variable_name, input_variable_name
+ )
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
magic.ipython_display.write.assert_called_once_with(expected_message)
assert not magic.ipython_display.send_error.called
+
@with_setup(_setup, _teardown)
def test_send_to_spark_should_write_successfully_when_everything_is_correct():
input_variable_name = "x_in"
@@ -217,42 +257,81 @@ def test_send_to_spark_should_write_successfully_when_everything_is_correct():
var_type = "str"
magic.shell.user_ns[input_variable_name] = input_variable_value
magic.spark_controller.run_command.return_value = (True, output_value, "text/plain")
- expected_message = u'Successfully passed \'{}\' as \'{}\' to Spark kernel'.format(input_variable_name, output_variable_name)
+ expected_message = "Successfully passed '{}' as '{}' to Spark kernel".format(
+ input_variable_name, output_variable_name
+ )
- magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None)
+ magic.do_send_to_spark(
+ "", input_variable_name, var_type, output_variable_name, max_rows, None
+ )
magic.ipython_display.write.assert_called_once_with(expected_message)
assert not magic.ipython_display.send_error.called
+
@with_setup(_setup, _teardown)
def test_spark_execution_without_output_var():
output_var = None
-
- magic.spark_controller.run_command.return_value = (True,'out',MIMETYPE_TEXT_PLAIN)
+
+ magic.spark_controller.run_command.return_value = (True, "out", MIMETYPE_TEXT_PLAIN)
magic.execute_spark("", output_var, None, None, None, session, None)
- magic.ipython_display.write.assert_called_once_with('out')
+ magic.ipython_display.write.assert_called_once_with("out")
assert not magic.spark_controller._spark_store_command.called
- magic.spark_controller.run_command.return_value = (False,'out',MIMETYPE_TEXT_PLAIN)
- assert_raises(SparkStatementException, magic.execute_spark,"", output_var, None, None, None, session, True)
+ magic.spark_controller.run_command.return_value = (
+ False,
+ "out",
+ MIMETYPE_TEXT_PLAIN,
+ )
+ assert_raises(
+ SparkStatementException,
+ magic.execute_spark,
+ "",
+ output_var,
+ None,
+ None,
+ None,
+ session,
+ True,
+ )
assert not magic.spark_controller._spark_store_command.called
+
@with_setup(_setup, _teardown)
def test_spark_execution_with_output_var():
mockSparkCommand = MagicMock()
magic._spark_store_command = MagicMock(return_value=mockSparkCommand)
output_var = "var_name"
- df = 'df'
+ df = "df"
- magic.spark_controller.run_command.side_effect = [(True,'out',MIMETYPE_TEXT_PLAIN), df]
+ magic.spark_controller.run_command.side_effect = [
+ (True, "out", MIMETYPE_TEXT_PLAIN),
+ df,
+ ]
magic.execute_spark("", output_var, None, None, None, session, True)
- magic.ipython_display.write.assert_called_once_with('out')
- magic._spark_store_command.assert_called_once_with(output_var, None, None, None, True)
+ magic.ipython_display.write.assert_called_once_with("out")
+ magic._spark_store_command.assert_called_once_with(
+ output_var, None, None, None, True
+ )
assert shell.user_ns[output_var] == df
magic.spark_controller.run_command.side_effect = None
- magic.spark_controller.run_command.return_value = (False,'out',MIMETYPE_TEXT_PLAIN)
- assert_raises(SparkStatementException, magic.execute_spark,"", output_var, None, None, None, session, True)
+ magic.spark_controller.run_command.return_value = (
+ False,
+ "out",
+ MIMETYPE_TEXT_PLAIN,
+ )
+ assert_raises(
+ SparkStatementException,
+ magic.execute_spark,
+ "",
+ output_var,
+ None,
+ None,
+ None,
+ session,
+ True,
+ )
@with_setup(_setup, _teardown)
@@ -261,34 +340,75 @@ def test_spark_exception_with_output_var():
magic._spark_store_command = MagicMock(return_value=mockSparkCommand)
exception = BadUserDataException("Ka-boom!")
output_var = "var_name"
- df = 'df'
-
- magic.spark_controller.run_command.side_effect = [(True,'out',MIMETYPE_TEXT_PLAIN), exception]
- assert_raises(BadUserDataException, magic.execute_spark,"", output_var, None, None, None, session, True)
- magic.ipython_display.write.assert_called_once_with('out')
- magic._spark_store_command.assert_called_once_with(output_var, None, None, None, True)
+ df = "df"
+
+ magic.spark_controller.run_command.side_effect = [
+ (True, "out", MIMETYPE_TEXT_PLAIN),
+ exception,
+ ]
+ assert_raises(
+ BadUserDataException,
+ magic.execute_spark,
+ "",
+ output_var,
+ None,
+ None,
+ None,
+ session,
+ True,
+ )
+ magic.ipython_display.write.assert_called_once_with("out")
+ magic._spark_store_command.assert_called_once_with(
+ output_var, None, None, None, True
+ )
assert shell.user_ns == {}
+
@with_setup(_setup, _teardown)
def test_spark_statement_exception():
mockSparkCommand = MagicMock()
magic._spark_store_command = MagicMock(return_value=mockSparkCommand)
exception = BadUserDataException("Ka-boom!")
- magic.spark_controller.run_command.side_effect = [(False, 'out', "text/plain"), exception]
- assert_raises(SparkStatementException, magic.execute_spark,"", None, None, None, None, session, True)
+ magic.spark_controller.run_command.side_effect = [
+ (False, "out", "text/plain"),
+ exception,
+ ]
+ assert_raises(
+ SparkStatementException,
+ magic.execute_spark,
+ "",
+ None,
+ None,
+ None,
+ None,
+ session,
+ True,
+ )
magic.spark_controller.cleanup.assert_not_called()
+
@with_setup(_setup, _teardown)
def test_spark_statement_exception_shutdowns_livy_session():
- conf.override_all({
- "shutdown_session_on_spark_statement_errors": True
- })
+ conf.override_all({"shutdown_session_on_spark_statement_errors": True})
mockSparkCommand = MagicMock()
magic._spark_store_command = MagicMock(return_value=mockSparkCommand)
exception = BadUserDataException("Ka-boom!")
- magic.spark_controller.run_command.side_effect = [(False, 'out', "text/plain"), exception]
- assert_raises(SparkStatementException, magic.execute_spark,"", None, None, None, None, session, True)
+ magic.spark_controller.run_command.side_effect = [
+ (False, "out", "text/plain"),
+ exception,
+ ]
+ assert_raises(
+ SparkStatementException,
+ magic.execute_spark,
+ "",
+ None,
+ None,
+ None,
+ None,
+ session,
+ True,
+ )
magic.spark_controller.cleanup.assert_called_once()
diff --git a/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py b/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py
index c011c6aae..d73bb7b27 100644
--- a/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py
+++ b/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py
@@ -12,17 +12,20 @@
backup_conf_defaults = None
+
def _setup():
global backup_conf_defaults
backup_conf_defaults = {
- 'samplemethod' : conf.default_samplemethod(),
- 'maxrows': conf.default_maxrows(),
- 'samplefraction': conf.default_samplefraction()
+ "samplemethod": conf.default_samplemethod(),
+ "maxrows": conf.default_maxrows(),
+ "samplefraction": conf.default_samplefraction(),
}
+
def _teardown():
conf.override_all(backup_conf_defaults)
+
@with_setup(_setup, _teardown)
def test_to_command_pyspark():
variable_name = "var_name"
@@ -54,7 +57,9 @@ def test_to_command_r():
def test_to_command_invalid():
variable_name = "var_name"
sparkcommand = SparkStoreCommand(variable_name)
- assert_raises(BadUserDataException, sparkcommand.to_command, "invalid", variable_name)
+ assert_raises(
+ BadUserDataException, sparkcommand.to_command, "invalid", variable_name
+ )
@with_setup(_setup, _teardown)
@@ -63,7 +68,9 @@ def test_sparkstorecommand_initializes():
samplemethod = "take"
maxrows = 120
samplefraction = 0.6
- sparkcommand = SparkStoreCommand(variable_name, samplemethod, maxrows, samplefraction)
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod, maxrows, samplefraction
+ )
assert_equals(sparkcommand.samplemethod, samplemethod)
assert_equals(sparkcommand.maxrows, maxrows)
assert_equals(sparkcommand.samplefraction, samplefraction)
@@ -79,137 +86,242 @@ def test_sparkstorecommand_loads_defaults():
conf.override_all(defaults)
variable_name = "var_name"
sparkcommand = SparkStoreCommand(variable_name)
- assert_equals(sparkcommand.samplemethod, defaults[conf.default_samplemethod.__name__])
+ assert_equals(
+ sparkcommand.samplemethod, defaults[conf.default_samplemethod.__name__]
+ )
assert_equals(sparkcommand.maxrows, defaults[conf.default_maxrows.__name__])
- assert_equals(sparkcommand.samplefraction, defaults[conf.default_samplefraction.__name__])
+ assert_equals(
+ sparkcommand.samplefraction, defaults[conf.default_samplefraction.__name__]
+ )
@with_setup(_setup, _teardown)
def test_pyspark_livy_sampling_options():
variable_name = "var_name"
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=120)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME)))
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=120)
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): "
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): "
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234
+ )
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): "
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
@with_setup(_setup, _teardown)
def test_scala_livy_sampling_options():
variable_name = "abc"
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=100)
- assert_equals(sparkcommand._scala_command(variable_name),
- Command('{}.toJSON.take(100).foreach(println)'.format(variable_name)))
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1)
- assert_equals(sparkcommand._scala_command(variable_name),
- Command('{}.toJSON.collect.foreach(println)'.format(variable_name)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sparkcommand._scala_command(variable_name),
- Command('{}.toJSON.sample(false, 0.25).collect.foreach(println)'.format(variable_name)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sparkcommand._scala_command(variable_name),
- Command('{}.toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(variable_name)))
-
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=100)
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command("{}.toJSON.take(100).foreach(println)".format(variable_name)),
+ )
+
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command("{}.toJSON.collect.foreach(println)".format(variable_name)),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1
+ )
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command(
+ "{}.toJSON.sample(false, 0.25).collect.foreach(println)".format(
+ variable_name
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234
+ )
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command(
+ "{}.toJSON.sample(false, 0.33).take(3234).foreach(println)".format(
+ variable_name
+ )
+ ),
+ )
sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100)
- assert_equals(sparkcommand._scala_command(variable_name),
- Command('{}.toJSON.take(100).foreach(println)'.format(variable_name)))
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command("{}.toJSON.take(100).foreach(println)".format(variable_name)),
+ )
+
@with_setup(_setup, _teardown)
def test_r_livy_sampling_options():
variable_name = "abc"
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=100)
-
- assert_equals(sparkcommand._r_command(variable_name),
- Command('for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1)
- assert_equals(sparkcommand._r_command(variable_name),
- Command('for ({} in (jsonlite::toJSON(collect({})))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sparkcommand._r_command(variable_name),
- Command('for ({} in (jsonlite::toJSON(collect(sample({}, FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME)))
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sparkcommand._r_command(variable_name),
- Command('for ({} in (jsonlite::toJSON(take(sample({}, FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME)))
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=100)
+
+ assert_equals(
+ sparkcommand._r_command(variable_name),
+ Command(
+ "for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sparkcommand._r_command(variable_name),
+ Command(
+ "for ({} in (jsonlite::toJSON(collect({})))) {{cat({})}}".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1
+ )
+ assert_equals(
+ sparkcommand._r_command(variable_name),
+ Command(
+ "for ({} in (jsonlite::toJSON(collect(sample({}, FALSE, 0.25))))) {{cat({})}}".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sparkcommand = SparkStoreCommand(
+ variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234
+ )
+ assert_equals(
+ sparkcommand._r_command(variable_name),
+ Command(
+ "for ({} in (jsonlite::toJSON(take(sample({}, FALSE, 0.33),3234)))) {{cat({})}}".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100)
- assert_equals(sparkcommand._r_command(variable_name),
- Command('for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
+ assert_equals(
+ sparkcommand._r_command(variable_name),
+ Command(
+ "for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
@with_setup(_setup, _teardown)
def test_execute_code():
spark_events = MagicMock()
variable_name = "abc"
- sparkcommand = SparkStoreCommand(variable_name, "take", 100, 0.2, spark_events=spark_events)
+ sparkcommand = SparkStoreCommand(
+ variable_name, "take", 100, 0.2, spark_events=spark_events
+ )
sparkcommand.to_command = MagicMock(return_value=MagicMock())
result = """{"z":100, "nullv":null, "y":50}
{"z":25, "nullv":null, "y":10}"""
- sparkcommand.to_command.return_value.execute = MagicMock(return_value=(True, result, MIMETYPE_TEXT_PLAIN))
+ sparkcommand.to_command.return_value.execute = MagicMock(
+ return_value=(True, result, MIMETYPE_TEXT_PLAIN)
+ )
session = MagicMock()
session.kind = "pyspark"
result = sparkcommand.execute(session)
-
+
sparkcommand.to_command.assert_called_once_with(session.kind, variable_name)
sparkcommand.to_command.return_value.execute.assert_called_once_with(session)
@with_setup(_setup, _teardown)
def test_unicode():
- variable_name = u"collect 'รจ'"
-
- sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=120)
- assert_equals(sparkcommand._pyspark_command(variable_name),
- Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, variable_name,
- LONG_RANDOM_VARIABLE_NAME)))
- assert_equals(sparkcommand._scala_command(variable_name),
- Command(u'{}.toJSON.take(120).foreach(println)'.format(variable_name)))
-
+ variable_name = "collect 'รจ'"
+
+ sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=120)
+ assert_equals(
+ sparkcommand._pyspark_command(variable_name),
+ Command(
+ "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+ assert_equals(
+ sparkcommand._scala_command(variable_name),
+ Command("{}.toJSON.take(120).foreach(println)".format(variable_name)),
+ )
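
The hunks above pin down, character for character, the per-language download command that SparkStoreCommand emits. As a reading aid, here is a minimal Python sketch of the rule those strings encode (illustrative names, not sparkmagic's implementation): maxrows of -1 means collect everything, any other value becomes take(maxrows), and samplemethod of "sample" inserts a sample(False, fraction) stage first.

```python
# Minimal sketch, assuming only what the assertions above show; the real
# SparkStoreCommand._pyspark_command may differ in details.
def pyspark_store_command(var, loop_var, samplemethod=None, maxrows=100, samplefraction=0.1):
    sample = ".sample(False, {})".format(samplefraction) if samplemethod == "sample" else ""
    fetch = ".collect()" if maxrows == -1 else ".take({})".format(maxrows)
    return (
        "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major"
        " > 2)){}{}: print({})".format(loop_var, var, sample, fetch, loop_var)
    )

# Reproduces the first expectation in test_pyspark_livy_sampling_options:
assert pyspark_store_command("var_name", "_v", "take", 120) == (
    "import sys\nfor _v in var_name.toJSON("
    "use_unicode=(sys.version_info.major > 2)).take(120): print(_v)"
)
```
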
diff --git a/sparkmagic/sparkmagic/tests/test_sqlquery.py b/sparkmagic/sparkmagic/tests/test_sqlquery.py
index cae49af8d..f685f0042 100644
--- a/sparkmagic/sparkmagic/tests/test_sqlquery.py
+++ b/sparkmagic/sparkmagic/tests/test_sqlquery.py
@@ -13,7 +13,7 @@
def _setup():
pass
-
+
def _teardown():
pass
@@ -54,7 +54,9 @@ def test_sqlquery_loads_defaults():
assert_equals(sqlquery.query, query)
assert_equals(sqlquery.samplemethod, defaults[conf.default_samplemethod.__name__])
assert_equals(sqlquery.maxrows, defaults[conf.default_maxrows.__name__])
- assert_equals(sqlquery.samplefraction, defaults[conf.default_samplefraction.__name__])
+ assert_equals(
+ sqlquery.samplefraction, defaults[conf.default_samplefraction.__name__]
+ )
@with_setup(_setup, _teardown)
@@ -69,101 +71,181 @@ def test_sqlquery_rejects_bad_data():
def test_pyspark_livy_sql_options():
query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=120)
- assert_equals(sqlquery._pyspark_command("sqlContext"),
- Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._pyspark_command("sqlContext"),
- Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._pyspark_command("sqlContext"),
- Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._pyspark_command("sqlContext"),
- Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME)))
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=120)
+ assert_equals(
+ sqlquery._pyspark_command("sqlContext"),
+ Command(
+ 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._pyspark_command("sqlContext"),
+ Command(
+ 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._pyspark_command("sqlContext"),
+ Command(
+ 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): '
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._pyspark_command("sqlContext"),
+ Command(
+ 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
@with_setup(_setup, _teardown)
def test_scala_livy_sql_options():
query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=100)
- assert_equals(sqlquery._scala_command("sqlContext"),
- Command('sqlContext.sql("""{}""").toJSON.take(100).foreach(println)'.format(query)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._scala_command("sqlContext"),
- Command('sqlContext.sql("""{}""").toJSON.collect.foreach(println)'.format(query)))
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._scala_command("sqlContext"),
- Command('sqlContext.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(query)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._scala_command("sqlContext"),
- Command('sqlContext.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(query)))
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=100)
+ assert_equals(
+ sqlquery._scala_command("sqlContext"),
+ Command(
+ 'sqlContext.sql("""{}""").toJSON.take(100).foreach(println)'.format(query)
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._scala_command("sqlContext"),
+ Command(
+ 'sqlContext.sql("""{}""").toJSON.collect.foreach(println)'.format(query)
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._scala_command("sqlContext"),
+ Command(
+ 'sqlContext.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(
+ query
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._scala_command("sqlContext"),
+ Command(
+ 'sqlContext.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(
+ query
+ )
+ ),
+ )
@with_setup(_setup, _teardown)
def test_r_livy_sql_options_spark():
- query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=100)
- sqlContext = "sqlContext"
-
- assert_equals(sqlquery._r_command(sqlContext),
- Command('for ({} in (jsonlite:::toJSON(take(sql({}, "{}"),100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._r_command(sqlContext),
- Command('for ({} in (jsonlite:::toJSON(collect(sql({}, "{}"))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._r_command(sqlContext),
- Command('for ({} in (jsonlite:::toJSON(collect(sample(sql({}, "{}"), FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._r_command(sqlContext),
- Command('for ({} in (jsonlite:::toJSON(take(sample(sql({}, "{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME)))
+ query = "abc"
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=100)
+ sqlContext = "sqlContext"
+
+ assert_equals(
+ sqlquery._r_command(sqlContext),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(take(sql({}, "{}"),100)))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._r_command(sqlContext),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(collect(sql({}, "{}"))))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._r_command(sqlContext),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(collect(sample(sql({}, "{}"), FALSE, 0.25))))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._r_command(sqlContext),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(take(sample(sql({}, "{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
@with_setup(_setup, _teardown)
def test_execute_sql():
spark_events = MagicMock()
- sqlquery = SQLQuery("HERE IS THE QUERY", "take", 100, 0.2, spark_events=spark_events)
+ sqlquery = SQLQuery(
+ "HERE IS THE QUERY", "take", 100, 0.2, spark_events=spark_events
+ )
sqlquery.to_command = MagicMock(return_value=MagicMock())
result = """{"z":100, "nullv":null, "y":50}
{"z":25, "nullv":null, "y":10}"""
- sqlquery.to_command.return_value.execute = MagicMock(return_value=(True, result, MIMETYPE_TEXT_PLAIN))
- result_data = pd.DataFrame([{'z': 100, "nullv": None, 'y': 50}, {'z':25, "nullv":None, 'y':10}], columns=['z', "nullv", 'y'])
+ sqlquery.to_command.return_value.execute = MagicMock(
+ return_value=(True, result, MIMETYPE_TEXT_PLAIN)
+ )
+ result_data = pd.DataFrame(
+ [{"z": 100, "nullv": None, "y": 50}, {"z": 25, "nullv": None, "y": 10}],
+ columns=["z", "nullv", "y"],
+ )
session = MagicMock()
session.kind = "pyspark"
result = sqlquery.execute(session)
assert_frame_equal(result, result_data)
sqlquery.to_command.return_value.execute.assert_called_once_with(session)
- spark_events.emit_sql_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, sqlquery.guid,
- 'take', 100, 0.2)
- spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, sqlquery.guid,
- sqlquery.to_command.return_value.guid,
- True, '','')
+ spark_events.emit_sql_execution_start_event.assert_called_once_with(
+ session.guid, session.kind, session.id, sqlquery.guid, "take", 100, 0.2
+ )
+ spark_events.emit_sql_execution_end_event.assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ sqlquery.guid,
+ sqlquery.to_command.return_value.guid,
+ True,
+ "",
+ "",
+ )
@with_setup(_setup, _teardown)
@@ -177,19 +259,34 @@ def test_execute_sql_no_results():
result1 = ""
result_data = pd.DataFrame([])
session = MagicMock()
- sqlquery.to_command.return_value.execute.return_value = (True, result1, MIMETYPE_TEXT_PLAIN)
+ sqlquery.to_command.return_value.execute.return_value = (
+ True,
+ result1,
+ MIMETYPE_TEXT_PLAIN,
+ )
session.kind = "spark"
result = sqlquery.execute(session)
assert_frame_equal(result, result_data)
sqlquery.to_command.return_value.execute.assert_called_once_with(session)
- spark_events.emit_sql_execution_start_event.assert_called_once_with(session.guid, session.kind,
- session.id, sqlquery.guid,
- sqlquery.samplemethod, sqlquery.maxrows,
- sqlquery.samplefraction)
- spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, sqlquery.guid,
- sqlquery.to_command.return_value.guid,
- True, "", "")
+ spark_events.emit_sql_execution_start_event.assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ sqlquery.guid,
+ sqlquery.samplemethod,
+ sqlquery.maxrows,
+ sqlquery.samplefraction,
+ )
+ spark_events.emit_sql_execution_end_event.assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ sqlquery.guid,
+ sqlquery.to_command.return_value.guid,
+ True,
+ "",
+ "",
+ )
@with_setup(_setup, _teardown)
@@ -197,7 +294,7 @@ def test_execute_sql_failure_emits_event():
spark_events = MagicMock()
sqlquery = SQLQuery("HERE IS THE QUERY", "take", 100, 0.2, spark_events)
sqlquery.to_command = MagicMock()
- sqlquery.to_command.return_value.execute = MagicMock(side_effect=ValueError('yo'))
+ sqlquery.to_command.return_value.execute = MagicMock(side_effect=ValueError("yo"))
session = MagicMock()
session.kind = "pyspark"
try:
@@ -205,93 +302,170 @@ def test_execute_sql_failure_emits_event():
assert False
except ValueError:
sqlquery.to_command.return_value.execute.assert_called_once_with(session)
- spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind,
- session.id, sqlquery.guid,
- sqlquery.to_command.return_value.guid,
- False, 'ValueError', 'yo')
+ spark_events.emit_sql_execution_end_event.assert_called_once_with(
+ session.guid,
+ session.kind,
+ session.id,
+ sqlquery.guid,
+ sqlquery.to_command.return_value.guid,
+ False,
+ "ValueError",
+ "yo",
+ )
@with_setup(_setup, _teardown)
def test_unicode_sql():
- query = u"SELECT 'รจ'"
+ query = "SELECT 'รจ'"
longvar = LONG_RANDOM_VARIABLE_NAME
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=120)
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\
- .format(longvar, query,
- longvar)))
- assert_equals(sqlquery._scala_command("spark"),
- Command(u'spark.sql("""{}""").toJSON.take(120).foreach(println)'.format(query)))
- assert_equals(sqlquery._r_command("spark"),
- Command(u'for ({} in (jsonlite:::toJSON(take(sql("{}"),120)))) {{cat({})}}'.format(longvar, query, longvar)))
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=120)
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format(
+ longvar, query, longvar
+ )
+ ),
+ )
+ assert_equals(
+ sqlquery._scala_command("spark"),
+ Command('spark.sql("""{}""").toJSON.take(120).foreach(println)'.format(query)),
+ )
+ assert_equals(
+ sqlquery._r_command("spark"),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(take(sql("{}"),120)))) {{cat({})}}'.format(
+ longvar, query, longvar
+ )
+ ),
+ )
+
@with_setup(_setup, _teardown)
def test_pyspark_livy_sql_options_spark2():
- query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=120)
-
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._pyspark_command("spark"),
- Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
- u'print({})'\
- .format(LONG_RANDOM_VARIABLE_NAME, query,
- LONG_RANDOM_VARIABLE_NAME)))
+ query = "abc"
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=120)
+
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): '
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._pyspark_command("spark"),
+ Command(
+ 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): '
+ "print({})".format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
@with_setup(_setup, _teardown)
def test_scala_livy_sql_options_spark2():
- query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=100)
-
- assert_equals(sqlquery._scala_command("spark"),
- Command('spark.sql("""{}""").toJSON.take(100).foreach(println)'.format(query)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._scala_command("spark"),
- Command('spark.sql("""{}""").toJSON.collect.foreach(println)'.format(query)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._scala_command("spark"),
- Command('spark.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(query)))
+ query = "abc"
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=100)
+
+ assert_equals(
+ sqlquery._scala_command("spark"),
+ Command('spark.sql("""{}""").toJSON.take(100).foreach(println)'.format(query)),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._scala_command("spark"),
+ Command('spark.sql("""{}""").toJSON.collect.foreach(println)'.format(query)),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._scala_command("spark"),
+ Command(
+ 'spark.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(
+ query
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._scala_command("spark"),
+ Command(
+ 'spark.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(
+ query
+ )
+ ),
+ )
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._scala_command("spark"),
- Command('spark.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(query)))
@with_setup(_setup, _teardown)
def test_r_livy_sql_options_spark2():
- query = "abc"
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=100)
-
- assert_equals(sqlquery._r_command("spark"),
- Command('for ({} in (jsonlite:::toJSON(take(sql("{}"),100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1)
- assert_equals(sqlquery._r_command("spark"),
- Command('for ({} in (jsonlite:::toJSON(collect(sql("{}"))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1)
- assert_equals(sqlquery._r_command("spark"),
- Command('for ({} in (jsonlite:::toJSON(collect(sample(sql("{}"), FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME)))
-
- sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234)
- assert_equals(sqlquery._r_command("spark"),
- Command('for ({} in (jsonlite:::toJSON(take(sample(sql("{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME)))
+ query = "abc"
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=100)
+
+ assert_equals(
+ sqlquery._r_command("spark"),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(take(sql("{}"),100)))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1)
+ assert_equals(
+ sqlquery._r_command("spark"),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(collect(sql("{}"))))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1)
+ assert_equals(
+ sqlquery._r_command("spark"),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(collect(sample(sql("{}"), FALSE, 0.25))))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
+
+ sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234)
+ assert_equals(
+ sqlquery._r_command("spark"),
+ Command(
+ 'for ({} in (jsonlite:::toJSON(take(sample(sql("{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(
+ LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME
+ )
+ ),
+ )
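
The SQLQuery expectations above follow the same sampling pipeline as SparkStoreCommand, with the query spliced into the session object's sql(...) call instead of a stored variable. A sketch of the Scala flavor under the same assumptions (names are illustrative, not the real internals):

```python
# Rebuilds the command strings asserted in test_scala_livy_sql_options;
# a sketch only, not SQLQuery._scala_command itself.
def scala_sql_command(session_var, query, samplemethod=None, maxrows=100, samplefraction=0.1):
    sample = ".sample(false, {})".format(samplefraction) if samplemethod == "sample" else ""
    fetch = ".collect" if maxrows == -1 else ".take({})".format(maxrows)
    return '{}.sql("""{}""").toJSON{}{}.foreach(println)'.format(
        session_var, query, sample, fetch
    )

assert scala_sql_command("spark", "abc", "sample", -1, 0.25) == (
    'spark.sql("""abc""").toJSON.sample(false, 0.25).collect.foreach(println)'
)
```
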
diff --git a/sparkmagic/sparkmagic/tests/test_usercodeparser.py b/sparkmagic/sparkmagic/tests/test_usercodeparser.py
index b1e54a67b..10cb776bb 100644
--- a/sparkmagic/sparkmagic/tests/test_usercodeparser.py
+++ b/sparkmagic/sparkmagic/tests/test_usercodeparser.py
@@ -8,67 +8,72 @@
def test_empty_string():
parser = UserCodeParser()
- assert_equals(u"", parser.get_code_to_run(u""))
+ assert_equals("", parser.get_code_to_run(""))
def test_spark_code():
parser = UserCodeParser()
- cell = u"my code\nand more"
+ cell = "my code\nand more"
- assert_equals(u"%%spark\nmy code\nand more", parser.get_code_to_run(cell))
+ assert_equals("%%spark\nmy code\nand more", parser.get_code_to_run(cell))
def test_local_single():
parser = UserCodeParser()
- cell = u"""%local
+ cell = """%local
hi
hi
hi"""
- assert_equals(u"hi\nhi\nhi", parser.get_code_to_run(cell))
+ assert_equals("hi\nhi\nhi", parser.get_code_to_run(cell))
def test_local_double():
parser = UserCodeParser()
- cell = u"""%%local
+ cell = """%%local
hi
hi
hi"""
- assert_equals(u"hi\nhi\nhi", parser.get_code_to_run(cell))
+ assert_equals("hi\nhi\nhi", parser.get_code_to_run(cell))
def test_our_line_magics():
parser = UserCodeParser()
magic_name = KernelMagics.info.__name__
- cell = u"%{}".format(magic_name)
+ cell = "%{}".format(magic_name)
- assert_equals(u"%%{}\n ".format(magic_name), parser.get_code_to_run(cell))
+ assert_equals("%%{}\n ".format(magic_name), parser.get_code_to_run(cell))
def test_our_line_magics_with_content():
parser = UserCodeParser()
magic_name = KernelMagics.info.__name__
- cell = u"""%{}
+ cell = """%{}
my content
-more content""".format(magic_name)
+more content""".format(
+ magic_name
+ )
- assert_equals(u"%%{}\nmy content\nmore content\n ".format(magic_name), parser.get_code_to_run(cell))
+ assert_equals(
+ "%%{}\nmy content\nmore content\n ".format(magic_name),
+ parser.get_code_to_run(cell),
+ )
def test_other_cell_magic():
parser = UserCodeParser()
- cell = u"""%%magic
+ cell = """%%magic
hi
hi
hi"""
- assert_equals(u"{}".format(cell), parser.get_code_to_run(cell))
+ assert_equals("{}".format(cell), parser.get_code_to_run(cell))
def test_other_line_magic():
parser = UserCodeParser()
- cell = u"""%magic
+ cell = """%magic
hi
hi
hi"""
@@ -78,36 +83,46 @@ def test_other_line_magic():
def test_scala_code():
parser = UserCodeParser()
- cell = u"""/* Place the cursor in the cell and press SHIFT + ENTER to run */
+ cell = """/* Place the cursor in the cell and press SHIFT + ENTER to run */
val fruits = sc.textFile("wasb:///example/data/fruits.txt")
val yellowThings = sc.textFile("wasb:///example/data/yellowthings.txt")"""
- assert_equals(u"%%spark\n{}".format(cell), parser.get_code_to_run(cell))
+ assert_equals("%%spark\n{}".format(cell), parser.get_code_to_run(cell))
def test_unicode():
parser = UserCodeParser()
- cell = u"print 'รจ๐๐๐๐'"
+ cell = "print 'รจ๐๐๐๐'"
- assert_equals(u"%%spark\n{}".format(cell), parser.get_code_to_run(cell))
+ assert_equals("%%spark\n{}".format(cell), parser.get_code_to_run(cell))
def test_unicode_in_magics():
parser = UserCodeParser()
magic_name = KernelMagics.info.__name__
- cell = u"""%{}
+ cell = """%{}
my content è๐
-more content""".format(magic_name)
+more content""".format(
+ magic_name
+ )
- assert_equals(u"%%{}\nmy content รจ๐\nmore content\n ".format(magic_name), parser.get_code_to_run(cell))
+ assert_equals(
+ "%%{}\nmy content รจ๐\nmore content\n ".format(magic_name),
+ parser.get_code_to_run(cell),
+ )
def test_unicode_in_double_magics():
parser = UserCodeParser()
magic_name = KernelMagics.info.__name__
- cell = u"""%%{}
+ cell = """%%{}
my content è๐
-more content""".format(magic_name)
-
- assert_equals(u"%%{}\nmy content รจ๐\nmore content\n ".format(magic_name), parser.get_code_to_run(cell))
+more content""".format(
+ magic_name
+ )
+
+ assert_equals(
+ "%%{}\nmy content รจ๐\nmore content\n ".format(magic_name),
+ parser.get_code_to_run(cell),
+ )
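
Taken together, the parser tests above describe a four-way dispatch: plain code is wrapped in %%spark, %local and %%local strip the magic and run the body locally, sparkmagic's own line magics are promoted to cell magics, and anything else passes through untouched. A simplified reconstruction from the assertions (assumption: the real UserCodeParser handles more cases than this):

```python
KNOWN_LINE_MAGICS = {"info"}  # stand-in for the KernelMagics names

def get_code_to_run(cell):
    if cell == "":
        return ""
    first, _, rest = cell.partition("\n")
    if first in ("%local", "%%local"):
        return rest                          # run locally, magic stripped
    name = first.lstrip("%")
    if first.startswith("%") and name in KNOWN_LINE_MAGICS:
        # promote to a cell magic; the trailing space keeps the cell body
        # non-empty for IPython
        return "%%{}\n{}".format(name, rest + "\n " if rest else " ")
    if first.startswith("%"):
        return cell                          # someone else's magic: untouched
    return "%%spark\n{}".format(cell)        # plain code runs on the cluster

assert get_code_to_run("my code\nand more") == "%%spark\nmy code\nand more"
assert get_code_to_run("%info") == "%%info\n "
```
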
diff --git a/sparkmagic/sparkmagic/tests/test_utils.py b/sparkmagic/sparkmagic/tests/test_utils.py
index 72779aedf..71fe67883 100644
--- a/sparkmagic/sparkmagic/tests/test_utils.py
+++ b/sparkmagic/sparkmagic/tests/test_utils.py
@@ -8,22 +8,32 @@
from sparkmagic.livyclientlib.exceptions import BadUserDataException
from sparkmagic.utils.utils import parse_argstring_or_throw, records_to_dataframe
from sparkmagic.utils.constants import SESSION_KIND_PYSPARK
-from sparkmagic.utils.dataframe_parser import DataframeHtmlParser, cell_contains_dataframe, CellComponentType, cell_components_iter, CellOutputHtmlParser
+from sparkmagic.utils.dataframe_parser import (
+ DataframeHtmlParser,
+ cell_contains_dataframe,
+ CellComponentType,
+ cell_components_iter,
+ CellOutputHtmlParser,
+)
import unittest
def test_parse_argstring_or_throw():
- parse_argstring = MagicMock(side_effect=UsageError('OOGABOOGABOOGA'))
+ parse_argstring = MagicMock(side_effect=UsageError("OOGABOOGABOOGA"))
try:
- parse_argstring_or_throw(MagicMock(), MagicMock(), parse_argstring=parse_argstring)
+ parse_argstring_or_throw(
+ MagicMock(), MagicMock(), parse_argstring=parse_argstring
+ )
assert False
except BadUserDataException as e:
assert_equals(str(e), str(parse_argstring.side_effect))
- parse_argstring = MagicMock(side_effect=ValueError('AN UNKNOWN ERROR HAPPENED'))
+ parse_argstring = MagicMock(side_effect=ValueError("AN UNKNOWN ERROR HAPPENED"))
try:
- parse_argstring_or_throw(MagicMock(), MagicMock(), parse_argstring=parse_argstring)
+ parse_argstring_or_throw(
+ MagicMock(), MagicMock(), parse_argstring=parse_argstring
+ )
assert False
except ValueError as e:
assert_is(e, parse_argstring.side_effect)
@@ -32,36 +42,51 @@ def test_parse_argstring_or_throw():
def test_records_to_dataframe_missing_value_first():
result = """{"z":100, "y":50}
{"z":25, "nullv":1.0, "y":10}"""
-
+
df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True)
- expected = pd.DataFrame([{'z': 100, "nullv": None, 'y': 50}, {'z':25, "nullv":1, 'y':10}], columns=['z', "nullv", 'y'])
+ expected = pd.DataFrame(
+ [{"z": 100, "nullv": None, "y": 50}, {"z": 25, "nullv": 1, "y": 10}],
+ columns=["z", "nullv", "y"],
+ )
assert_frame_equal(expected, df)
def test_records_to_dataframe_coercing():
result = """{"z":"100", "y":"2016-01-01"}
{"z":"25", "y":"2016-01-01"}"""
-
+
df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True)
- expected = pd.DataFrame([{'z': 100, 'y': np.datetime64("2016-01-01")}, {'z':25, 'y':np.datetime64("2016-01-01")}], columns=['z', 'y'])
+ expected = pd.DataFrame(
+ [
+ {"z": 100, "y": np.datetime64("2016-01-01")},
+ {"z": 25, "y": np.datetime64("2016-01-01")},
+ ],
+ columns=["z", "y"],
+ )
assert_frame_equal(expected, df)
def test_records_to_dataframe_no_coercing():
result = """{"z":"100", "y":"2016-01-01"}
{"z":"25", "y":"2016-01-01"}"""
-
+
df = records_to_dataframe(result, SESSION_KIND_PYSPARK, False)
- expected = pd.DataFrame([{'z': "100", 'y': "2016-01-01"}, {'z':"25", 'y':"2016-01-01"}], columns=['z', 'y'])
+ expected = pd.DataFrame(
+ [{"z": "100", "y": "2016-01-01"}, {"z": "25", "y": "2016-01-01"}],
+ columns=["z", "y"],
+ )
assert_frame_equal(expected, df)
def test_records_to_dataframe_missing_value_later():
result = """{"z":25, "nullv":1.0, "y":10}
{"z":100, "y":50}"""
-
+
df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True)
- expected = pd.DataFrame([{'z':25, "nullv":1, 'y':10}, {'z': 100, "nullv": None, 'y': 50}], columns=['z', "nullv", 'y'])
+ expected = pd.DataFrame(
+ [{"z": 25, "nullv": 1, "y": 10}, {"z": 100, "nullv": None, "y": 50}],
+ columns=["z", "nullv", "y"],
+ )
assert_frame_equal(expected, df)
@@ -80,9 +105,9 @@ def test_dataframe_component(self):
"""
dc = DataframeHtmlParser(cell)
rows = dc.row_iter()
- self.assertDictEqual(next(rows), {'id': '1', 'animal': 'bat'})
- self.assertDictEqual(next(rows), {'id': '2', 'animal': 'mouse'})
- self.assertDictEqual(next(rows), {'id': '3', 'animal': 'horse'})
+ self.assertDictEqual(next(rows), {"id": "1", "animal": "bat"})
+ self.assertDictEqual(next(rows), {"id": "2", "animal": "mouse"})
+ self.assertDictEqual(next(rows), {"id": "3", "animal": "horse"})
with self.assertRaises(StopIteration):
next(rows)
@@ -93,7 +118,7 @@ def test_dataframe_component(self):
+---+------+
"""
dc = DataframeHtmlParser(cell)
- rows = dc.row_iter()
+ rows = dc.row_iter()
with self.assertRaises(StopIteration):
next(rows)
@@ -109,13 +134,13 @@ def test_dataframe_component(self):
Only showing the last 20 rows
"""
dc = DataframeHtmlParser(cell)
- rows = dc.row_iter()
- self.assertDictEqual(next(rows), {'id': '1', 'animal': 'bat'})
+ rows = dc.row_iter()
+ self.assertDictEqual(next(rows), {"id": "1", "animal": "bat"})
with self.assertRaises(ValueError):
next(rows)
def test_dataframe_parsing(self):
-
+
cell = """+---+------+
| id|animal|
+---+------+
@@ -139,8 +164,9 @@ def test_dataframe_parsing(self):
Only showing the last 20 rows
"""
- self.assertTrue(cell_contains_dataframe(cell), "Matches with leading whitespaces")
-
+ self.assertTrue(
+ cell_contains_dataframe(cell), "Matches with leading whitespaces"
+ )
cell = """
+---+------+
@@ -160,7 +186,9 @@ def test_dataframe_parsing(self):
+---+------+
"""
- self.assertTrue(cell_contains_dataframe(cell), "Cell contains multiple dataframes")
+ self.assertTrue(
+ cell_contains_dataframe(cell), "Cell contains multiple dataframes"
+ )
cell = """
+---+------+
@@ -190,7 +218,7 @@ def test_dataframe_parsing(self):
Only showing the last 20 rows
"""
self.assertFalse(cell_contains_dataframe(cell), "Footer contains a /")
-
+
def test_cell_components_iter(self):
cell = """+---+------+
| id|animal|
@@ -203,7 +231,7 @@ def test_cell_components_iter(self):
Only showing the last 20 rows
"""
df_spans = cell_components_iter(cell)
-
+
self.assertEqual(next(df_spans), (CellComponentType.DF, 0, 90))
self.assertEqual(next(df_spans), (CellComponentType.TEXT, 90, 143))
self.assertEqual(cell[90:143].strip(), "Only showing the last 20 rows")
@@ -235,7 +263,7 @@ def test_cell_components_iter(self):
Random stuff at the end
"""
df_spans = cell_components_iter(cell)
-
+
self.assertEqual(next(df_spans), (CellComponentType.TEXT, 0, 45))
self.assertEqual(cell[0:45].strip(), "Random stuff at the start")
self.assertEqual(next(df_spans), (CellComponentType.DF, 45, 247))
@@ -247,7 +275,7 @@ def test_cell_components_iter(self):
self.assertEqual(next(df_spans), (CellComponentType.TEXT, 495, 553))
self.assertEqual(cell[495:553].strip(), "Random stuff at the end")
-
+
with self.assertRaises(StopIteration):
next(df_spans)
@@ -257,7 +285,7 @@ def test_cell_components_iter(self):
"""
df_spans = cell_components_iter(cell)
-
+
self.assertEqual(next(df_spans), (CellComponentType.TEXT, 0, len(cell)))
with self.assertRaises(StopIteration):
next(df_spans)
@@ -272,7 +300,9 @@ def test_output_html_parser(self):
self.assertEqual(CellOutputHtmlParser.to_html(cell), "")
cell = "Some random text"
- self.assertEqual(CellOutputHtmlParser.to_html(cell), "<pre>Some random text</pre>")
+ self.assertEqual(
+ CellOutputHtmlParser.to_html(cell), "<pre>Some random text</pre>"
+ )
cell = """
@@ -298,18 +328,23 @@ def test_output_html_parser(self):
Random stuff at the end
"""
self.maxDiff = 1200
- self.assertEqual(CellOutputHtmlParser.to_html(cell),
- ("<pre>Random stuff at the start</pre>"
- "<table><tr><th>id</th><th>animal</th></tr>"
- "<tr><td>1</td><td>bat</td></tr>"
- "<tr><td>2</td><td>mouse</td></tr>"
- "<tr><td>3</td><td>horse</td></tr>"
- "</table><pre>Random stuff in the middle</pre>"
- "<table><tr><th>id</th><th>animal</th></tr>"
- "<tr><td>1</td><td>cat</td></tr>"
- "<tr><td>2</td><td>couse</td></tr>"
- "<tr><td>3</td><td>corse</td></tr>"
- "</table><pre>Random stuff at the end</pre>"))
-
-if __name__ == '__main__':
+ self.assertEqual(
+ CellOutputHtmlParser.to_html(cell),
+ (
+ "Random stuff at the start
"
+ "id | animal |
"
+ "1 | bat |
"
+ "2 | mouse |
"
+ "3 | horse |
"
+ "
Random stuff in the middle
"
+ "id | animal |
"
+ "1 | cat |
"
+ "2 | couse |
"
+ "3 | corse |
"
+ "
Random stuff at the end
"
+ ),
+ )
+
+
+if __name__ == "__main__":
unittest.main()
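
The records_to_dataframe cases above treat Livy output as newline-delimited JSON records that become a pandas DataFrame, with optional dtype coercion that leaves unparseable columns as strings. A sketch of that behavior as inferred from the expectations, not the actual implementation:

```python
import json

import pandas as pd

# Assumed coercion order (numbers first, then datetimes), reconstructed
# from the test expectations; the real records_to_dataframe may differ.
def json_lines_to_df(text, coerce=True):
    df = pd.DataFrame([json.loads(line) for line in text.splitlines()])
    if coerce:
        for col in df.columns:
            for cast in (pd.to_numeric, pd.to_datetime):
                try:
                    df[col] = cast(df[col])
                    break
                except (ValueError, TypeError):
                    pass  # leave the column as-is
    return df

df = json_lines_to_df('{"z":"100", "y":"2016-01-01"}\n{"z":"25", "y":"2016-01-01"}')
print(df.dtypes)  # z coerces to int64, y to datetime64[ns]
```
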
diff --git a/sparkmagic/sparkmagic/utils/configuration.py b/sparkmagic/sparkmagic/utils/configuration.py
index 515492b43..68a9a95ec 100644
--- a/sparkmagic/sparkmagic/utils/configuration.py
+++ b/sparkmagic/sparkmagic/utils/configuration.py
@@ -2,18 +2,33 @@
import copy
import sys
import base64
-from hdijupyterutils.constants import EVENTS_HANDLER_CLASS_NAME, LOGGING_CONFIG_CLASS_NAME
+from hdijupyterutils.constants import (
+ EVENTS_HANDLER_CLASS_NAME,
+ LOGGING_CONFIG_CLASS_NAME,
+)
from hdijupyterutils.utils import join_paths
from hdijupyterutils.configuration import override as _override
from hdijupyterutils.configuration import override_all as _override_all
from hdijupyterutils.configuration import with_override
-from .constants import HOME_PATH, CONFIG_FILE, MAGICS_LOGGER_NAME, LIVY_KIND_PARAM, \
- LANG_SCALA, LANG_PYTHON, LANG_R, \
- SESSION_KIND_SPARKR, SESSION_KIND_SPARK, SESSION_KIND_PYSPARK, CONFIGURABLE_RETRY, \
- NO_AUTH, AUTH_BASIC
+from .constants import (
+ HOME_PATH,
+ CONFIG_FILE,
+ MAGICS_LOGGER_NAME,
+ LIVY_KIND_PARAM,
+ LANG_SCALA,
+ LANG_PYTHON,
+ LANG_R,
+ SESSION_KIND_SPARKR,
+ SESSION_KIND_SPARK,
+ SESSION_KIND_PYSPARK,
+ CONFIGURABLE_RETRY,
+ NO_AUTH,
+ AUTH_BASIC,
+)
from sparkmagic.livyclientlib.exceptions import BadUserConfigurationException
-#import sparkmagic.utils.constants as constants
+
+# import sparkmagic.utils.constants as constants
from requests_kerberos import REQUIRED
@@ -22,7 +37,7 @@
d = {}
path = join_paths(HOME_PATH, CONFIG_FILE)
-
+
def override(config, value):
_override(d, path, config, value)
@@ -36,6 +51,7 @@ def override_all(obj):
# Helpers
+
def get_livy_kind(language):
if language == LANG_SCALA:
return SESSION_KIND_SPARK
@@ -44,26 +60,29 @@ def get_livy_kind(language):
elif language == LANG_R:
return SESSION_KIND_SPARKR
else:
- raise BadUserConfigurationException("Cannot get session kind for {}.".format(language))
+ raise BadUserConfigurationException(
+ "Cannot get session kind for {}.".format(language)
+ )
def get_auth_value(username, password):
- if username == '' and password == '':
+ if username == "" and password == "":
return NO_AUTH
return AUTH_BASIC
@_with_override
def authenticators():
- return {
- u"Kerberos": u"sparkmagic.auth.kerberos.Kerberos",
- u"None": u"sparkmagic.auth.customauth.Authenticator",
- u"Basic_Access": u"sparkmagic.auth.basic.Basic"
+ return {
+ "Kerberos": "sparkmagic.auth.kerberos.Kerberos",
+ "None": "sparkmagic.auth.customauth.Authenticator",
+ "Basic_Access": "sparkmagic.auth.basic.Basic",
}
-
-
+
+
# Configs
+
def get_session_properties(language):
properties = copy.deepcopy(session_configs())
properties[LIVY_KIND_PARAM] = get_livy_kind(language)
@@ -77,9 +96,14 @@ def session_configs():
@_with_override
def kernel_python_credentials():
- return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH}
-
-
+ return {
+ "username": "",
+ "base64_password": "",
+ "url": "http://localhost:8998",
+ "auth": NO_AUTH,
+ }
+
+
def base64_kernel_python_credentials():
return _credentials_override(kernel_python_credentials)
@@ -96,15 +120,26 @@ def base64_kernel_python3_credentials():
@_with_override
def kernel_scala_credentials():
- return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH}
+ return {
+ "username": "",
+ "base64_password": "",
+ "url": "http://localhost:8998",
+ "auth": NO_AUTH,
+ }
-def base64_kernel_scala_credentials():
+def base64_kernel_scala_credentials():
return _credentials_override(kernel_scala_credentials)
+
@_with_override
def kernel_r_credentials():
- return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH}
+ return {
+ "username": "",
+ "base64_password": "",
+ "url": "http://localhost:8998",
+ "auth": NO_AUTH,
+ }
def base64_kernel_r_credentials():
@@ -114,27 +149,27 @@ def base64_kernel_r_credentials():
@_with_override
def logging_config():
return {
- u"version": 1,
- u"formatters": {
- u"magicsFormatter": {
- u"format": u"%(asctime)s\t%(levelname)s\t%(message)s",
- u"datefmt": u""
+ "version": 1,
+ "formatters": {
+ "magicsFormatter": {
+ "format": "%(asctime)s\t%(levelname)s\t%(message)s",
+ "datefmt": "",
}
},
- u"handlers": {
- u"magicsHandler": {
- u"class": LOGGING_CONFIG_CLASS_NAME,
- u"formatter": u"magicsFormatter",
- u"home_path": HOME_PATH
+ "handlers": {
+ "magicsHandler": {
+ "class": LOGGING_CONFIG_CLASS_NAME,
+ "formatter": "magicsFormatter",
+ "home_path": HOME_PATH,
}
},
- u"loggers": {
+ "loggers": {
MAGICS_LOGGER_NAME: {
- u"handlers": [u"magicsHandler"],
- u"level": u"DEBUG",
- u"propagate": 0
+ "handlers": ["magicsHandler"],
+ "level": "DEBUG",
+ "propagate": 0,
}
- }
+ },
}
@@ -155,7 +190,7 @@ def livy_session_startup_timeout_seconds():
@_with_override
def fatal_error_suggestion():
- return u"""The code failed because of a fatal error:
+ return """The code failed because of a fatal error:
\t{}.
Some things to try:
@@ -201,7 +236,7 @@ def default_samplefraction():
@_with_override
def pyspark_dataframe_encoding():
- return u'utf-8'
+ return "utf-8"
@_with_override
@@ -267,9 +302,7 @@ def cleanup_all_sessions_on_exit():
@_with_override
def kerberos_auth_configuration():
- return {
- "mutual_authentication": REQUIRED
- }
+ return {"mutual_authentication": REQUIRED}
def _credentials_override(f):
@@ -278,15 +311,26 @@ def _credentials_override(f):
If 'base64_password' is not set, it will fallback to 'password' in config.
"""
credentials = f()
- base64_decoded_credentials = {k: credentials.get(k) for k in ('username', 'password', 'url', 'auth')}
- base64_password = credentials.get('base64_password')
+ base64_decoded_credentials = {
+ k: credentials.get(k) for k in ("username", "password", "url", "auth")
+ }
+ base64_password = credentials.get("base64_password")
if base64_password is not None:
try:
- base64_decoded_credentials['password'] = base64.b64decode(base64_password).decode()
+ base64_decoded_credentials["password"] = base64.b64decode(
+ base64_password
+ ).decode()
except Exception:
exception_type, exception, traceback = sys.exc_info()
- msg = "base64_password for %s contains invalid base64 string: %s %s" % (f.__name__, exception_type, exception)
+ msg = "base64_password for %s contains invalid base64 string: %s %s" % (
+ f.__name__,
+ exception_type,
+ exception,
+ )
raise BadUserConfigurationException(msg)
- if base64_decoded_credentials['auth'] is None:
- base64_decoded_credentials['auth'] = get_auth_value(base64_decoded_credentials['username'], base64_decoded_credentials['password'])
+ if base64_decoded_credentials["auth"] is None:
+ base64_decoded_credentials["auth"] = get_auth_value(
+ base64_decoded_credentials["username"],
+ base64_decoded_credentials["password"],
+ )
return base64_decoded_credentials
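
The _credentials_override reformatting above reads as three steps: copy the plain fields, decode base64_password into password, and derive auth only when it was left unset. A toy version with illustrative names (not sparkmagic's public API), using the NO_AUTH and AUTH_BASIC values visible in these hunks:

```python
import base64

def resolve_credentials(conf):
    creds = {k: conf.get(k) for k in ("username", "password", "url", "auth")}
    b64 = conf.get("base64_password")
    if b64 is not None:
        # a real implementation should surface decode errors as a
        # configuration problem, as the hunk above does
        creds["password"] = base64.b64decode(b64).decode()
    if creds["auth"] is None:
        no_auth = creds["username"] == "" and creds["password"] == ""
        creds["auth"] = "None" if no_auth else "Basic_Access"
    return creds

print(resolve_credentials(
    {"username": "u", "base64_password": "cHc=", "url": "http://localhost:8998", "auth": None}
))  # auth resolves to 'Basic_Access', password decodes to 'pw'
```
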
diff --git a/sparkmagic/sparkmagic/utils/constants.py b/sparkmagic/sparkmagic/utils/constants.py
index 10fc7e926..59b7c099c 100644
--- a/sparkmagic/sparkmagic/utils/constants.py
+++ b/sparkmagic/sparkmagic/utils/constants.py
@@ -9,7 +9,11 @@
SESSION_KIND_SPARK = "spark"
SESSION_KIND_PYSPARK = "pyspark"
SESSION_KIND_SPARKR = "sparkr"
-SESSION_KINDS_SUPPORTED = [SESSION_KIND_SPARK, SESSION_KIND_PYSPARK, SESSION_KIND_SPARKR]
+SESSION_KINDS_SUPPORTED = [
+ SESSION_KIND_SPARK,
+ SESSION_KIND_PYSPARK,
+ SESSION_KIND_SPARKR,
+]
LIBRARY_LOADED_EVENT = "notebookLoaded"
CLUSTER_CHANGE_EVENT = "notebookClusterChange"
@@ -72,34 +76,57 @@
KILLED_SESSION_STATUS = "killed"
RECOVERING_SESSION_STATUS = "recovering"
-POSSIBLE_SESSION_STATUS = [NOT_STARTED_SESSION_STATUS, IDLE_SESSION_STATUS, STARTING_SESSION_STATUS,
- BUSY_SESSION_STATUS, ERROR_SESSION_STATUS, DEAD_SESSION_STATUS,
- SUCCESS_SESSION_STATUS, SHUT_DOWN_SESSION_STATUS, RUNNING_SESSION_STATUS,
- KILLED_SESSION_STATUS, RECOVERING_SESSION_STATUS]
-FINAL_STATUS = [DEAD_SESSION_STATUS, ERROR_SESSION_STATUS, SUCCESS_SESSION_STATUS,
- KILLED_SESSION_STATUS]
+POSSIBLE_SESSION_STATUS = [
+ NOT_STARTED_SESSION_STATUS,
+ IDLE_SESSION_STATUS,
+ STARTING_SESSION_STATUS,
+ BUSY_SESSION_STATUS,
+ ERROR_SESSION_STATUS,
+ DEAD_SESSION_STATUS,
+ SUCCESS_SESSION_STATUS,
+ SHUT_DOWN_SESSION_STATUS,
+ RUNNING_SESSION_STATUS,
+ KILLED_SESSION_STATUS,
+ RECOVERING_SESSION_STATUS,
+]
+FINAL_STATUS = [
+ DEAD_SESSION_STATUS,
+ ERROR_SESSION_STATUS,
+ SUCCESS_SESSION_STATUS,
+ KILLED_SESSION_STATUS,
+]
ERROR_STATEMENT_STATUS = "error"
CANCELLED_STATEMENT_STATUS = "cancelled"
AVAILABLE_STATEMENT_STATUS = "available"
-FINAL_STATEMENT_STATUS = [ERROR_STATEMENT_STATUS, CANCELLED_STATEMENT_STATUS, AVAILABLE_STATEMENT_STATUS]
+FINAL_STATEMENT_STATUS = [
+ ERROR_STATEMENT_STATUS,
+ CANCELLED_STATEMENT_STATUS,
+ AVAILABLE_STATEMENT_STATUS,
+]
DELETE_SESSION_ACTION = "delete"
START_SESSION_ACTION = "start"
DO_NOTHING_ACTION = "nothing"
-INTERNAL_ERROR_MSG = "An internal error was encountered.\n" \
- "Please file an issue at https://github.com/jupyter-incubator/sparkmagic\nError:\n{}"
+INTERNAL_ERROR_MSG = (
+ "An internal error was encountered.\n"
+ "Please file an issue at https://github.com/jupyter-incubator/sparkmagic\nError:\n{}"
+)
EXPECTED_ERROR_MSG = "An error was encountered:\n{}"
YARN_RESOURCE_LIMIT_MSG = "Queue's AM resource limit exceeded."
-RESOURCE_LIMIT_WARNING = "Warning: The Spark session does not have enough YARN resources to start. {}"
+RESOURCE_LIMIT_WARNING = (
+ "Warning: The Spark session does not have enough YARN resources to start. {}"
+)
COMMAND_INTERRUPTED_MSG = "Interrupted by user"
-COMMAND_CANCELLATION_FAILED_MSG = "Interrupted by user but Livy failed to cancel the Spark statement. "\
- "The Livy session might have become unusable."
+COMMAND_CANCELLATION_FAILED_MSG = (
+ "Interrupted by user but Livy failed to cancel the Spark statement. "
+ "The Livy session might have become unusable."
+)
-LIVY_HEARTBEAT_TIMEOUT_PARAM = u"heartbeatTimeoutInSecond"
-LIVY_KIND_PARAM = u"kind"
+LIVY_HEARTBEAT_TIMEOUT_PARAM = "heartbeatTimeoutInSecond"
+LIVY_KIND_PARAM = "kind"
NO_AUTH = "None"
AUTH_KERBEROS = "Kerberos"
diff --git a/sparkmagic/sparkmagic/utils/dataframe_parser.py b/sparkmagic/sparkmagic/utils/dataframe_parser.py
index 993fb4882..03d842f88 100644
--- a/sparkmagic/sparkmagic/utils/dataframe_parser.py
+++ b/sparkmagic/sparkmagic/utils/dataframe_parser.py
@@ -1,6 +1,6 @@
import re
-import itertools
-from collections import OrderedDict
+import itertools
+from collections import OrderedDict
from enum import Enum
from functools import partial
@@ -33,22 +33,28 @@
"""
-header_top_pattern = r'[ \t\f\v]*(?P<header_top>\+[-+]*\+)[\n\r]?'
-header_content_pattern = r'[ \t\f\v]*(?P<header_content>\|.*\|)[\n\r]'
-header_bottom_pattern = r'(?P<header_bottom>[ \t\f\v]*\+[-+]*\+[\n\r])'
-row_content_pattern = r'([ \t\f\v]*\|.*\|[\n\r])*'
-footer_pattern = r'(?P
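
For context on the named-group patterns above: they recognize the +---+ borders and pipe-delimited rows that Spark prints for DataFrames. A small, simplified demonstration of that kind of matching (patterns shortened, names hypothetical):

```python
import re

# Simplified versions of the border/row patterns; the real module also
# tracks header, footer, and span offsets.
DF_BORDER = re.compile(r"[ \t\f\v]*(?P<border>\+[-+]*\+)[\n\r]?")
DF_ROW = re.compile(r"[ \t\f\v]*\|.*\|[\n\r]")

cell = "+---+------+\n| id|animal|\n+---+------+\n|  1|   bat|\n+---+------+\n"
assert DF_BORDER.match(cell)            # top border found
assert len(DF_ROW.findall(cell)) == 2   # header row plus one data row
```
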