diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..fdd2777c7 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,4 @@ +# .git-blame-ignore-revs +# Re-formatted entire code base with black +7ebf0753485c931db4135953dcd0864b4d089ed5 + diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a9c6d249a..6c50ace97 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -3,6 +3,7 @@ ### Checklist - [ ] Wrote a description of my changes above +- [ ] Formatted my code with [`black`](https://black.readthedocs.io/en/stable/index.html) - [ ] Added a bullet point for my changes to the top of the `CHANGELOG.md` file - [ ] Added or modified unit tests to reflect my changes - [ ] Manually tested with a notebook diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 000000000..b04fb15cb --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable diff --git a/CHANGELOG.md b/CHANGELOG.md index d51ace3e8..86b755c27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ ### Updates * Officially drop support for Python 3.6. Sparkmagic will not guarantee Python 3.6 compatibility moving forward. +* Re-format all code with [`black`](https://black.readthedocs.io/en/stable/index.html) and validate via CI + ### Bug Fixes diff --git a/README.md b/README.md index 6e01367f6..f059e26da 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -[![Build Status](https://travis-ci.org/jupyter-incubator/sparkmagic.svg?branch=master)](https://travis-ci.org/jupyter-incubator/sparkmagic) [![Join the chat at https://gitter.im/sparkmagic/Lobby](https://badges.gitter.im/sparkmagic/Lobby.svg)](https://gitter.im/sparkmagic/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Build Status](https://travis-ci.org/jupyter-incubator/sparkmagic.svg?branch=master)](https://travis-ci.org/jupyter-incubator/sparkmagic) [![Join the chat at https://gitter.im/sparkmagic/Lobby](https://badges.gitter.im/sparkmagic/Lobby.svg)](https://gitter.im/sparkmagic/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + # sparkmagic diff --git a/autovizwidget/autovizwidget/plotlygraphs/bargraph.py b/autovizwidget/autovizwidget/plotlygraphs/bargraph.py index 61e671beb..4de275e4e 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/bargraph.py +++ b/autovizwidget/autovizwidget/plotlygraphs/bargraph.py @@ -10,4 +10,3 @@ class BarGraph(GraphBase): def _get_data(self, df, encoding): x_values, y_values = GraphBase._get_x_y_values(df, encoding) return [Bar(x=x_values, y=y_values)] - diff --git a/autovizwidget/autovizwidget/plotlygraphs/datagraph.py b/autovizwidget/autovizwidget/plotlygraphs/datagraph.py index 19021a23d..358413d07 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/datagraph.py +++ b/autovizwidget/autovizwidget/plotlygraphs/datagraph.py @@ -8,6 +8,7 @@ class DataGraph(object): """This does not use the table version of plotly because it freezes up the browser for >60 rows. 
Instead, we use pandas df HTML representation.""" + def __init__(self, display=None): if display is None: self.display = IpythonDisplay() @@ -21,7 +22,8 @@ def render(self, df, encoding, output): show_dimensions = pd.get_option("display.show_dimensions") # This will hide the index column for pandas df. - self.display.html(""" + self.display.html( + """ -""") - self.display.html(df.to_html(max_rows=max_rows, max_cols=max_cols, - show_dimensions=show_dimensions, notebook=True, classes="hideme")) +""" + ) + self.display.html( + df.to_html( + max_rows=max_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + notebook=True, + classes="hideme", + ) + ) @staticmethod def display_logarithmic_x_axis(): diff --git a/autovizwidget/autovizwidget/plotlygraphs/graphbase.py b/autovizwidget/autovizwidget/plotlygraphs/graphbase.py index 9b46f906c..7e09fe8b7 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/graphbase.py +++ b/autovizwidget/autovizwidget/plotlygraphs/graphbase.py @@ -3,6 +3,7 @@ from plotly.graph_objs import Figure, Layout from plotly.offline import iplot + try: from pandas.core.base import DataError except: @@ -29,16 +30,20 @@ def render(self, df, encoding, output): type_x_axis = self._get_type_axis(encoding.logarithmic_x_axis) type_y_axis = self._get_type_axis(encoding.logarithmic_y_axis) - layout = Layout(xaxis=dict(type=type_x_axis, rangemode="tozero", title=encoding.x), - yaxis=dict(type=type_y_axis, rangemode="tozero", title=encoding.y)) + layout = Layout( + xaxis=dict(type=type_x_axis, rangemode="tozero", title=encoding.x), + yaxis=dict(type=type_y_axis, rangemode="tozero", title=encoding.y), + ) with output: try: fig = Figure(data=data, layout=layout) iplot(fig, show_link=False) except TypeError: - print("\n\n\nPlease select another set of X and Y axis, because the type of the current axis do\n" - "not support aggregation over it.") + print( + "\n\n\nPlease select another set of X and Y axis, because the type of the current axis do\n" + "not support aggregation over it." + ) @staticmethod def display_x(): @@ -68,10 +73,9 @@ def _get_data(self, df, encoding): @staticmethod def _get_x_y_values(df, encoding): try: - x_values, y_values = GraphBase._get_x_y_values_aggregated(df, - encoding.x, - encoding.y, - encoding.y_aggregation) + x_values, y_values = GraphBase._get_x_y_values_aggregated( + df, encoding.x, encoding.y, encoding.y_aggregation + ) except ValueError: x_values = GraphBase._get_x_values(df, encoding) y_values = GraphBase._get_y_values(df, encoding) @@ -99,8 +103,11 @@ def _get_x_y_values_aggregated(df, x_column, y_column, y_aggregation): try: df_grouped = df.groupby(x_column) except TypeError: - raise InvalidEncodingError("Cannot group by X column '{}' because of its type: '{}'." - .format(df[x_column].dtype)) + raise InvalidEncodingError( + "Cannot group by X column '{}' because of its type: '{}'.".format( + df[x_column].dtype + ) + ) else: try: if y_aggregation == Encoding.y_agg_avg: @@ -114,20 +121,30 @@ def _get_x_y_values_aggregated(df, x_column, y_column, y_aggregation): elif y_aggregation == Encoding.y_agg_count: df_transformed = df_grouped.count() else: - raise ValueError("Y aggregation '{}' not supported.".format(y_aggregation)) + raise ValueError( + "Y aggregation '{}' not supported.".format(y_aggregation) + ) except (DataError, ValueError) as err: - raise InvalidEncodingError("Cannot aggregate column '{}' with aggregation function '{}' because:\n\t'{}'." 
- .format(y_column, y_aggregation, err)) + raise InvalidEncodingError( + "Cannot aggregate column '{}' with aggregation function '{}' because:\n\t'{}'.".format( + y_column, y_aggregation, err + ) + ) except TypeError: - raise InvalidEncodingError("Cannot aggregate column '{}' with aggregation function '{}' because the type\n" - "cannot be aggregated over." - .format(y_column, y_aggregation)) + raise InvalidEncodingError( + "Cannot aggregate column '{}' with aggregation function '{}' because the type\n" + "cannot be aggregated over.".format(y_column, y_aggregation) + ) else: df_transformed = df_transformed.reset_index() if y_column not in df_transformed.columns: - raise InvalidEncodingError("Y column '{}' is not valid with aggregation function '{}'. Please select " - "a different\naggregation function.".format(y_column, y_aggregation)) + raise InvalidEncodingError( + "Y column '{}' is not valid with aggregation function '{}'. Please select " + "a different\naggregation function.".format( + y_column, y_aggregation + ) + ) x_values = df_transformed[x_column].tolist() y_values = df_transformed[y_column].tolist() diff --git a/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py b/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py index 82dbf8868..9d6ff1755 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py +++ b/autovizwidget/autovizwidget/plotlygraphs/graphrenderer.py @@ -14,7 +14,6 @@ class GraphRenderer(object): - @staticmethod def render(df, encoding, output): with output: diff --git a/autovizwidget/autovizwidget/plotlygraphs/linegraph.py b/autovizwidget/autovizwidget/plotlygraphs/linegraph.py index 35356bd8d..e3612f7b2 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/linegraph.py +++ b/autovizwidget/autovizwidget/plotlygraphs/linegraph.py @@ -7,7 +7,6 @@ class LineGraph(GraphBase): - def _get_data(self, df, encoding): x_values, y_values = GraphBase._get_x_y_values(df, encoding) return [Scatter(x=x_values, y=y_values)] diff --git a/autovizwidget/autovizwidget/plotlygraphs/piegraph.py b/autovizwidget/autovizwidget/plotlygraphs/piegraph.py index 10b1a3fb5..d383c924f 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/piegraph.py +++ b/autovizwidget/autovizwidget/plotlygraphs/piegraph.py @@ -3,11 +3,12 @@ from plotly.graph_objs import Pie, Figure from plotly.offline import iplot + try: from pandas.core.base import DataError except: from pandas.core.groupby import DataError - + import autovizwidget.utils.configuration as conf from .graphbase import GraphBase @@ -24,13 +25,19 @@ def render(df, encoding, output): values, labels = PieGraph._get_x_values_labels(df, encoding) except TypeError: with output: - print("\n\n\nCannot group by X selection because of its type: '{}'. Please select another column." - .format(df[encoding.x].dtype)) + print( + "\n\n\nCannot group by X selection because of its type: '{}'. Please select another column.".format( + df[encoding.x].dtype + ) + ) return except (ValueError, DataError): with output: - print("\n\n\nCannot group by X selection. Please select another column." - .format(df[encoding.x].dtype)) + print( + "\n\n\nCannot group by X selection. Please select another column.".format( + df[encoding.x].dtype + ) + ) if df.size == 0: print("\n\n\nCannot display a pie graph for an empty data set.") return @@ -42,9 +49,12 @@ def render(df, encoding, output): # 500 rows take ~15 s. # 100 rows is almost automatic. 
if len(values) > max_slices_pie_graph: - print("There's {} values in your pie graph, which would render the graph unresponsive.\n" - "Please select another X with at most {} possible values." - .format(len(values), max_slices_pie_graph)) + print( + "There's {} values in your pie graph, which would render the graph unresponsive.\n" + "Please select another X with at most {} possible values.".format( + len(values), max_slices_pie_graph + ) + ) else: data = [Pie(values=values, labels=labels)] diff --git a/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py b/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py index 92c88f175..ebf938886 100644 --- a/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py +++ b/autovizwidget/autovizwidget/plotlygraphs/scattergraph.py @@ -4,7 +4,6 @@ class ScatterGraph(GraphBase): - def _get_data(self, df, encoding): x_values, y_values = GraphBase._get_x_y_values(df, encoding) - return [Scatter(x=x_values, y=y_values, mode='markers')] + return [Scatter(x=x_values, y=y_values, mode="markers")] diff --git a/autovizwidget/autovizwidget/tests/test_autovizwidget.py b/autovizwidget/autovizwidget/tests/test_autovizwidget.py index 1f7b7a555..ee648b970 100644 --- a/autovizwidget/autovizwidget/tests/test_autovizwidget.py +++ b/autovizwidget/autovizwidget/tests/test_autovizwidget.py @@ -27,12 +27,14 @@ def _setup(): renderer.display_logarithmic_x_axis.return_value = True renderer.display_logarithmic_y_axis.return_value = True - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32}, + ] df = pd.DataFrame(records) encoding = Encoding(chart_type="table", x="date", y="temp_diff") @@ -53,8 +55,16 @@ def _teardown(): @with_setup(_setup, _teardown) def test_on_render_viz(): - widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory, - encoding_widget, ipython_display, spark_events=spark_events, testing=True) + widget = AutoVizWidget( + df, + encoding, + renderer, + ipywidget_factory, + encoding_widget, + ipython_display, + spark_events=spark_events, + testing=True, + ) # on_render_viz is called in the constructor, so no need to call it here. 
output.clear_output.assert_called_once_with() @@ -76,59 +86,129 @@ def test_on_render_viz(): encoding._chart_type = Encoding.chart_type_scatter widget.on_render_viz() assert_equals(len(spark_events.emit_graph_render_event.mock_calls), 2) - assert_equals(spark_events.emit_graph_render_event.call_args, call(Encoding.chart_type_scatter)) + assert_equals( + spark_events.emit_graph_render_event.call_args, + call(Encoding.chart_type_scatter), + ) @with_setup(_setup, _teardown) def test_create_viz_types_buttons(): - df_single_column = pd.DataFrame([{u'buildingID': 0}]) - widget = AutoVizWidget(df_single_column, encoding, renderer, ipywidget_factory, - encoding_widget, ipython_display, spark_events=spark_events, testing=True) + df_single_column = pd.DataFrame([{"buildingID": 0}]) + widget = AutoVizWidget( + df_single_column, + encoding, + renderer, + ipywidget_factory, + encoding_widget, + ipython_display, + spark_events=spark_events, + testing=True, + ) # create_viz_types_buttons is called in the constructor, so no need to call it here. - assert call(description=Encoding.chart_type_table) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_pie) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_line) not in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_area) not in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_bar) not in ipywidget_factory.get_button.mock_calls + assert ( + call(description=Encoding.chart_type_table) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_pie) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_line) + not in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_area) + not in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_bar) + not in ipywidget_factory.get_button.mock_calls + ) spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type) - widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory, - encoding_widget, ipython_display, spark_events=spark_events, testing=True) + widget = AutoVizWidget( + df, + encoding, + renderer, + ipywidget_factory, + encoding_widget, + ipython_display, + spark_events=spark_events, + testing=True, + ) # create_viz_types_buttons is called in the constructor, so no need to call it here. 
- assert call(description=Encoding.chart_type_table) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_pie) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_line) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_area) in ipywidget_factory.get_button.mock_calls - assert call(description=Encoding.chart_type_bar) in ipywidget_factory.get_button.mock_calls + assert ( + call(description=Encoding.chart_type_table) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_pie) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_line) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_area) + in ipywidget_factory.get_button.mock_calls + ) + assert ( + call(description=Encoding.chart_type_bar) + in ipywidget_factory.get_button.mock_calls + ) @with_setup(_setup, _teardown) def test_create_viz_empty_df(): df = pd.DataFrame([]) - widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory, - encoding_widget, ipython_display, spark_events=spark_events, testing=True) + widget = AutoVizWidget( + df, + encoding, + renderer, + ipywidget_factory, + encoding_widget, + ipython_display, + spark_events=spark_events, + testing=True, + ) ipywidget_factory.get_button.assert_not_called() ipywidget_factory.get_html.assert_called_once_with("No results.") ipython_display.display.assert_called_with(ipywidget_factory.get_html.return_value) spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type) + @with_setup(_setup, _teardown) def test_convert_to_displayable_dataframe(): - bool_df = pd.DataFrame([{u'bool_col': True, u'int_col': 0, u'float_col': 3.0}, - {u'bool_col': False, u'int_col': 100, u'float_col': 0.7}]) + bool_df = pd.DataFrame( + [ + {"bool_col": True, "int_col": 0, "float_col": 3.0}, + {"bool_col": False, "int_col": 100, "float_col": 0.7}, + ] + ) copy_of_df = bool_df.copy() - widget = AutoVizWidget(df, encoding, renderer, ipywidget_factory, - encoding_widget, ipython_display, spark_events=spark_events, testing=True) + widget = AutoVizWidget( + df, + encoding, + renderer, + ipywidget_factory, + encoding_widget, + ipython_display, + spark_events=spark_events, + testing=True, + ) result = AutoVizWidget._convert_to_displayable_dataframe(bool_df) # Ensure original DF not changed assert_frame_equal(bool_df, copy_of_df) - assert_series_equal(bool_df[u'int_col'], result[u'int_col']) - assert_series_equal(bool_df[u'float_col'], result[u'float_col']) - assert_equals(result.dtypes[u'bool_col'], object) - assert_equals(len(result[u'bool_col']), 2) - assert_equals(result[u'bool_col'][0], 'True') - assert_equals(result[u'bool_col'][1], 'False') + assert_series_equal(bool_df["int_col"], result["int_col"]) + assert_series_equal(bool_df["float_col"], result["float_col"]) + assert_equals(result.dtypes["bool_col"], object) + assert_equals(len(result["bool_col"]), 2) + assert_equals(result["bool_col"][0], "True") + assert_equals(result["bool_col"][1], "False") spark_events.emit_graph_render_event.assert_called_once_with(encoding.chart_type) diff --git a/autovizwidget/autovizwidget/tests/test_encodingwidget.py b/autovizwidget/autovizwidget/tests/test_encodingwidget.py index 622a5add7..3153375ec 100644 --- a/autovizwidget/autovizwidget/tests/test_encodingwidget.py +++ b/autovizwidget/autovizwidget/tests/test_encodingwidget.py @@ -16,12 +16,14 @@ def 
_setup(): global df, encoding, ipywidget_factory, change_hook - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, u'\u263A': True}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, u'\u263A': True}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, u'\u263A': True}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, u'\u263A': True}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, u'\u263A': True}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, u'\u263A': True}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12, "\u263A": True}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0, "\u263A": True}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11, "\u263A": True}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5, "\u263A": True}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19, "\u263A": True}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32, "\u263A": True}, + ] df = pd.DataFrame(records) encoding = Encoding(chart_type="table", x="date", y="temp_diff") @@ -38,12 +40,14 @@ def _teardown(): @with_setup(_setup, _teardown) def test_encoding_with_all_none_doesnt_throw(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32}, + ] df = pd.DataFrame(records) encoding = Encoding() @@ -53,16 +57,47 @@ def test_encoding_with_all_none_doesnt_throw(): EncodingWidget(df, encoding, change_hook, ipywidget_factory, testing=True) - assert call(description='X', value=None, options={'date': 'date', 'temp_diff': 'temp_diff', '-': None, - 'buildingID': 'buildingID'}) \ + assert ( + call( + description="X", + value=None, + options={ + "date": "date", + "temp_diff": "temp_diff", + "-": None, + "buildingID": "buildingID", + }, + ) in ipywidget_factory.get_dropdown.mock_calls - assert call(description='Y', value=None, options={'date': 'date', 'temp_diff': 'temp_diff', '-': None, - 'buildingID': 'buildingID'}) \ + ) + assert ( + call( + description="Y", + value=None, + options={ + "date": "date", + "temp_diff": "temp_diff", + "-": None, + "buildingID": "buildingID", + }, + ) in ipywidget_factory.get_dropdown.mock_calls - assert call(description='Func.', value='none', options={'Max': 'Max', 'Sum': 'Sum', 'Avg': 'Avg', - '-': 'None', 'Min': 'Min', 'Count': 'Count'}) \ + ) + assert ( + call( + description="Func.", + value="none", + options={ + "Max": "Max", + "Sum": "Sum", + "Avg": "Avg", + "-": "None", + "Min": "Min", + "Count": "Count", + }, + ) in ipywidget_factory.get_dropdown.mock_calls - + ) @with_setup(_setup, _teardown) diff --git a/autovizwidget/autovizwidget/tests/test_plotlygraphs.py b/autovizwidget/autovizwidget/tests/test_plotlygraphs.py index 427ed1b45..5b4672660 100644 --- a/autovizwidget/autovizwidget/tests/test_plotlygraphs.py +++ b/autovizwidget/autovizwidget/tests/test_plotlygraphs.py @@ -16,51 +16,86 @@ def test_graph_base_display_methods(): def 
test_graphbase_get_x_y_values(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, u"str": "str"}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, u"str": "str"}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, u"str": "str"}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, u"str": "str"}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, u"str": "str"}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, u"str": "str"}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12, "str": "str"}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0, "str": "str"}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11, "str": "str"}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5, "str": "str"}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19, "str": "str"}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32, "str": "str"}, + ] df = pd.DataFrame(records) - expected_xs = [u'6/1/13', u'6/1/14', u'6/1/15', u'6/1/16', u'6/1/17'] - - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum) + expected_xs = ["6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"] + + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_sum, + ) xs, yx = GraphBase._get_x_y_values(df, encoding) assert xs == expected_xs assert yx == [12, 11, 5, 19, 32] - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_avg) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_avg, + ) xs, yx = GraphBase._get_x_y_values(df, encoding) assert xs == expected_xs assert yx == [6, 11, 5, 19, 32] - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_max) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_max, + ) xs, yx = GraphBase._get_x_y_values(df, encoding) assert xs == expected_xs assert yx == [12, 11, 5, 19, 32] - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_min) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_min, + ) xs, yx = GraphBase._get_x_y_values(df, encoding) assert xs == expected_xs assert yx == [0, 11, 5, 19, 32] - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_none) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_none, + ) xs, yx = GraphBase._get_x_y_values(df, encoding) - assert xs == [u'6/1/13', u'6/1/13', u'6/1/14', u'6/1/15', u'6/1/16', u'6/1/17'] + assert xs == ["6/1/13", "6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"] assert yx == [12, 0, 11, 5, 19, 32] try: - encoding = Encoding(chart_type=Encoding.chart_type_line, x="buildingID", y="date", - y_aggregation=Encoding.y_agg_avg) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="buildingID", + y="date", + y_aggregation=Encoding.y_agg_avg, + ) GraphBase._get_x_y_values(df, encoding) assert False except InvalidEncodingError: pass try: - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="str", - y_aggregation=Encoding.y_agg_avg) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + 
y="str", + y_aggregation=Encoding.y_agg_avg, + ) GraphBase._get_x_y_values(df, encoding) assert False except InvalidEncodingError: @@ -75,21 +110,33 @@ def test_pie_graph_display_methods(): def test_pie_graph_get_values_labels(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32}, + ] df = pd.DataFrame(records) - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y=None, y_aggregation=Encoding.y_agg_sum) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y=None, + y_aggregation=Encoding.y_agg_sum, + ) values, labels = PieGraph._get_x_values_labels(df, encoding) assert values == [2, 1, 1, 1, 1] assert labels == ["6/1/13", "6/1/14", "6/1/15", "6/1/16", "6/1/17"] - - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum) + + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_sum, + ) values, labels = PieGraph._get_x_values_labels(df, encoding) @@ -98,14 +145,21 @@ def test_pie_graph_get_values_labels(): def test_data_graph_render(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": 12}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": 0}, + {"buildingID": 2, "date": "6/1/14", "temp_diff": 11}, + {"buildingID": 0, "date": "6/1/15", "temp_diff": 5}, + {"buildingID": 1, "date": "6/1/16", "temp_diff": 19}, + {"buildingID": 2, "date": "6/1/17", "temp_diff": 32}, + ] df = pd.DataFrame(records) - encoding = Encoding(chart_type=Encoding.chart_type_line, x="date", y="temp_diff", y_aggregation=Encoding.y_agg_sum) + encoding = Encoding( + chart_type=Encoding.chart_type_line, + x="date", + y="temp_diff", + y_aggregation=Encoding.y_agg_sum, + ) display = MagicMock() data = DataGraph(display) diff --git a/autovizwidget/autovizwidget/tests/test_sparkevents.py b/autovizwidget/autovizwidget/tests/test_sparkevents.py index 33493901c..73de4bb01 100644 --- a/autovizwidget/autovizwidget/tests/test_sparkevents.py +++ b/autovizwidget/autovizwidget/tests/test_sparkevents.py @@ -24,29 +24,33 @@ def _teardown(): @with_setup(_setup, _teardown) def test_not_emit_graph_render_event_when_not_registered(): event_name = GRAPH_RENDER_EVENT - graph_type = 'Bar' + graph_type = "Bar" - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (GRAPH_TYPE, graph_type)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + 
(GRAPH_TYPE, graph_type), + ] events.emit_graph_render_event(graph_type) events.get_utc_date_time.assert_called_with() assert not events.handler.handle_event.called - - + + @with_setup(_setup, _teardown) def test_emit_graph_render_event_when_registered(): conf.override(conf.events_handler.__name__, events.handler) event_name = GRAPH_RENDER_EVENT - graph_type = 'Bar' - - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (GRAPH_TYPE, graph_type)] + graph_type = "Bar" + + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (GRAPH_TYPE, graph_type), + ] events.emit_graph_render_event(graph_type) diff --git a/autovizwidget/autovizwidget/tests/test_utils.py b/autovizwidget/autovizwidget/tests/test_utils.py index a26f98e1f..9abb29ad1 100644 --- a/autovizwidget/autovizwidget/tests/test_utils.py +++ b/autovizwidget/autovizwidget/tests/test_utils.py @@ -12,12 +12,50 @@ def _setup(): global df, encoding - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': 12, "mystr": "alejandro", "mystr2": "1"}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': 0, "mystr": "alejandro", "mystr2": "1"}, - {u'buildingID': 2, u'date': u'6/1/14', u'temp_diff': 11, "mystr": "alejandro", "mystr2": "1"}, - {u'buildingID': 0, u'date': u'6/1/15', u'temp_diff': 5, "mystr": "alejandro", "mystr2": "1.0"}, - {u'buildingID': 1, u'date': u'6/1/16', u'temp_diff': 19, "mystr": "alejandro", "mystr2": "1"}, - {u'buildingID': 2, u'date': u'6/1/17', u'temp_diff': 32, "mystr": "alejandro", "mystr2": "1"}] + records = [ + { + "buildingID": 0, + "date": "6/1/13", + "temp_diff": 12, + "mystr": "alejandro", + "mystr2": "1", + }, + { + "buildingID": 1, + "date": "6/1/13", + "temp_diff": 0, + "mystr": "alejandro", + "mystr2": "1", + }, + { + "buildingID": 2, + "date": "6/1/14", + "temp_diff": 11, + "mystr": "alejandro", + "mystr2": "1", + }, + { + "buildingID": 0, + "date": "6/1/15", + "temp_diff": 5, + "mystr": "alejandro", + "mystr2": "1.0", + }, + { + "buildingID": 1, + "date": "6/1/16", + "temp_diff": 19, + "mystr": "alejandro", + "mystr2": "1", + }, + { + "buildingID": 2, + "date": "6/1/17", + "temp_diff": 32, + "mystr": "alejandro", + "mystr2": "1", + }, + ] df = pd.DataFrame(records) encoding = Encoding(chart_type="table", x="date", y="temp_diff") @@ -46,24 +84,27 @@ def _check(d, expected): x = utils.select_x(d) assert x == expected - data = dict(col1=[1.0, 2.0, 3.0], # Q - col2=['A', 'B', 'C'], # N - col3=pd.date_range('2012', periods=3, freq='A')) # T - _check(data, 'col3') + data = dict( + col1=[1.0, 2.0, 3.0], # Q + col2=["A", "B", "C"], # N + col3=pd.date_range("2012", periods=3, freq="A"), + ) # T + _check(data, "col3") - data = dict(col1=[1.0, 2.0, 3.0], # Q - col2=['A', 'B', 'C']) # N - _check(data, 'col2') + data = dict(col1=[1.0, 2.0, 3.0], col2=["A", "B", "C"]) # Q # N + _check(data, "col2") data = dict(col1=[1.0, 2.0, 3.0]) # Q - _check(data, 'col1') + _check(data, "col1") # Custom order - data = dict(col1=[1.0, 2.0, 3.0], # Q - col2=['A', 'B', 'C'], # N - col3=pd.date_range('2012', periods=3, freq='A'), # T - col4=pd.date_range('2012', periods=3, freq='A')) # T - selected_x = utils.select_x(data, ['N', 'T', 'Q', 'O']) + data = dict( + col1=[1.0, 2.0, 3.0], # Q + col2=["A", "B", "C"], # N + col3=pd.date_range("2012", periods=3, freq="A"), # T + col4=pd.date_range("2012", periods=3, freq="A"), + ) # T + selected_x = utils.select_x(data, ["N", "T", "Q", "O"]) assert selected_x == "col2" # Len < 1 
@@ -72,25 +113,31 @@ def _check(d, expected): def test_select_y(): def _check(d, expected): - x = 'col1' + x = "col1" y = utils.select_y(d, x) assert y == expected - data = dict(col1=[1.0, 2.0, 3.0], # Chosen X - col2=['A', 'B', 'C'], # N - col3=pd.date_range('2012', periods=3, freq='A'), # T - col4=pd.date_range('2012', periods=3, freq='A'), # T - col5=[1.0, 2.0, 3.0]) # Q - _check(data, 'col5') - - data = dict(col1=[1.0, 2.0, 3.0], # Chosen X - col2=['A', 'B', 'C'], # N - col3=pd.date_range('2012', periods=3, freq='A')) # T - _check(data, 'col2') - - data = dict(col1=[1.0, 2.0, 3.0], # Chosen X - col2=pd.date_range('2012', periods=3, freq='A')) # T - _check(data, 'col2') + data = dict( + col1=[1.0, 2.0, 3.0], # Chosen X + col2=["A", "B", "C"], # N + col3=pd.date_range("2012", periods=3, freq="A"), # T + col4=pd.date_range("2012", periods=3, freq="A"), # T + col5=[1.0, 2.0, 3.0], + ) # Q + _check(data, "col5") + + data = dict( + col1=[1.0, 2.0, 3.0], # Chosen X + col2=["A", "B", "C"], # N + col3=pd.date_range("2012", periods=3, freq="A"), + ) # T + _check(data, "col2") + + data = dict( + col1=[1.0, 2.0, 3.0], # Chosen X + col2=pd.date_range("2012", periods=3, freq="A"), + ) # T + _check(data, "col2") # No data assert utils.select_y(None, "something") is None @@ -102,12 +149,14 @@ def _check(d, expected): assert utils.select_y(df, None) is None # Custom order - data = dict(col1=[1.0, 2.0, 3.0], # Chosen X - col2=['A', 'B', 'C'], # N - col3=pd.date_range('2012', periods=3, freq='A'), # T - col4=pd.date_range('2012', periods=3, freq='A'), # T - col5=[1.0, 2.0, 3.0], # Q - col6=[1.0, 2.0, 3.0]) # Q - selected_x = 'col1' - selected_y = utils.select_y(data, selected_x, ['N', 'T', 'Q', 'O']) - assert selected_y == 'col2' + data = dict( + col1=[1.0, 2.0, 3.0], # Chosen X + col2=["A", "B", "C"], # N + col3=pd.date_range("2012", periods=3, freq="A"), # T + col4=pd.date_range("2012", periods=3, freq="A"), # T + col5=[1.0, 2.0, 3.0], # Q + col6=[1.0, 2.0, 3.0], + ) # Q + selected_x = "col1" + selected_y = utils.select_y(data, selected_x, ["N", "T", "Q", "O"]) + assert selected_y == "col2" diff --git a/autovizwidget/autovizwidget/utils/configuration.py b/autovizwidget/autovizwidget/utils/configuration.py index 1cd048525..7e58f28f5 100644 --- a/autovizwidget/autovizwidget/utils/configuration.py +++ b/autovizwidget/autovizwidget/utils/configuration.py @@ -1,12 +1,15 @@ # Distributed under the terms of the Modified BSD License. 
-from hdijupyterutils.constants import EVENTS_HANDLER_CLASS_NAME, LOGGING_CONFIG_CLASS_NAME +from hdijupyterutils.constants import ( + EVENTS_HANDLER_CLASS_NAME, + LOGGING_CONFIG_CLASS_NAME, +) from hdijupyterutils.utils import join_paths from hdijupyterutils.configuration import override as _override from hdijupyterutils.configuration import override_all as _override_all from hdijupyterutils.configuration import with_override from .constants import HOME_PATH, CONFIG_FILE - + d = {} path = join_paths(HOME_PATH, CONFIG_FILE) @@ -18,15 +21,18 @@ def override(config, value): def override_all(obj): _override_all(d, obj) - + + _with_override = with_override(d, path) - + # Configs + @_with_override def events_handler(): return None + @_with_override def max_slices_pie_graph(): return 100 diff --git a/autovizwidget/autovizwidget/utils/events.py b/autovizwidget/autovizwidget/utils/events.py index f3415ab90..87ee3a79a 100644 --- a/autovizwidget/autovizwidget/utils/events.py +++ b/autovizwidget/autovizwidget/utils/events.py @@ -17,9 +17,11 @@ def emit_graph_render_event(self, graph_type): event_name = GRAPH_RENDER_EVENT time_stamp = self.get_utc_date_time() - kwargs_list = [(EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (GRAPH_TYPE, graph_type)] - - if self.emit: + kwargs_list = [ + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (GRAPH_TYPE, graph_type), + ] + + if self.emit: self.send_to_handler(kwargs_list) diff --git a/autovizwidget/autovizwidget/widget/autovizwidget.py b/autovizwidget/autovizwidget/widget/autovizwidget.py index 280785e1e..0c70c3fec 100644 --- a/autovizwidget/autovizwidget/widget/autovizwidget.py +++ b/autovizwidget/autovizwidget/widget/autovizwidget.py @@ -13,13 +13,24 @@ class AutoVizWidget(Box): - def __init__(self, df, encoding, renderer=None, ipywidget_factory=None, encoding_widget=None, ipython_display=None, - nested_widget_mode=False, spark_events=None, testing=False, **kwargs): + def __init__( + self, + df, + encoding, + renderer=None, + ipywidget_factory=None, + encoding_widget=None, + ipython_display=None, + nested_widget_mode=False, + spark_events=None, + testing=False, + **kwargs + ): assert encoding is not None assert df is not None assert type(df) is pd.DataFrame - kwargs['orientation'] = 'vertical' + kwargs["orientation"] = "vertical" if not testing: super(AutoVizWidget, self).__init__((), **kwargs) @@ -74,14 +85,22 @@ def on_render_viz(self, *args): self.encoding_widget.show_x(self.renderer.display_x(self.encoding.chart_type)) self.encoding_widget.show_y(self.renderer.display_y(self.encoding.chart_type)) - self.encoding_widget.show_controls(self.renderer.display_controls(self.encoding.chart_type)) - self.encoding_widget.show_logarithmic_x_axis(self.renderer.display_logarithmic_x_axis(self.encoding.chart_type)) - self.encoding_widget.show_logarithmic_y_axis(self.renderer.display_logarithmic_y_axis(self.encoding.chart_type)) + self.encoding_widget.show_controls( + self.renderer.display_controls(self.encoding.chart_type) + ) + self.encoding_widget.show_logarithmic_x_axis( + self.renderer.display_logarithmic_x_axis(self.encoding.chart_type) + ) + self.encoding_widget.show_logarithmic_y_axis( + self.renderer.display_logarithmic_y_axis(self.encoding.chart_type) + ) if len(self.df) > 0: self.renderer.render(self.df, self.encoding, self.to_display) else: with self.to_display: - self.ipython_display.display(self.ipywidget_factory.get_html('No results.')) + self.ipython_display.display( + self.ipywidget_factory.get_html("No results.") + ) def 
_create_controls_widget(self): # Create types of viz hbox @@ -97,7 +116,9 @@ def _create_viz_types_buttons(self): children = list() if len(self.df) > 0: - self.heading = self.ipywidget_factory.get_html('Type:', width='80px', height='32px') + self.heading = self.ipywidget_factory.get_html( + "Type:", width="80px", height="32px" + ) children.append(self.heading) self._create_type_button(Encoding.chart_type_table, children) @@ -130,6 +151,6 @@ def _convert_to_displayable_dataframe(df): df = df.copy() # Convert all booleans to string because Plotly doesn't know how to plot booleans, # but it does know how to plot strings. - bool_columns = list(df.select_dtypes(include=['bool']).columns) + bool_columns = list(df.select_dtypes(include=["bool"]).columns) df[bool_columns] = df[bool_columns].astype(str) return df diff --git a/autovizwidget/autovizwidget/widget/encoding.py b/autovizwidget/autovizwidget/widget/encoding.py index 9c778aaa5..f2f57114f 100644 --- a/autovizwidget/autovizwidget/widget/encoding.py +++ b/autovizwidget/autovizwidget/widget/encoding.py @@ -9,7 +9,13 @@ class Encoding(object): chart_type_bar = "Bar" chart_type_pie = "Pie" chart_type_table = "Table" - supported_chart_types = [chart_type_line, chart_type_area, chart_type_bar, chart_type_pie, chart_type_table] + supported_chart_types = [ + chart_type_line, + chart_type_area, + chart_type_bar, + chart_type_pie, + chart_type_table, + ] y_agg_avg = "Avg" y_agg_min = "Min" @@ -17,10 +23,24 @@ class Encoding(object): y_agg_sum = "Sum" y_agg_none = "None" y_agg_count = "Count" - supported_y_agg = [y_agg_avg, y_agg_min, y_agg_max, y_agg_sum, y_agg_none, y_agg_count] - - def __init__(self, chart_type=None, x=None, y=None, y_aggregation=None, - logarithmic_x_axis=False, logarithmic_y_axis=False): + supported_y_agg = [ + y_agg_avg, + y_agg_min, + y_agg_max, + y_agg_sum, + y_agg_none, + y_agg_count, + ] + + def __init__( + self, + chart_type=None, + x=None, + y=None, + y_aggregation=None, + logarithmic_x_axis=False, + logarithmic_y_axis=False, + ): self._chart_type = chart_type self._x = x self._y = y diff --git a/autovizwidget/autovizwidget/widget/encodingwidget.py b/autovizwidget/autovizwidget/widget/encodingwidget.py index 5690d85ec..38bb0af6d 100644 --- a/autovizwidget/autovizwidget/widget/encodingwidget.py +++ b/autovizwidget/autovizwidget/widget/encodingwidget.py @@ -17,12 +17,14 @@ class EncodingWidget(Box): - def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=False, **kwargs): + def __init__( + self, df, encoding, change_hook, ipywidget_factory=None, testing=False, **kwargs + ): assert encoding is not None assert df is not None assert type(df) is pd.DataFrame - kwargs['orientation'] = 'vertical' + kwargs["orientation"] = "vertical" if not testing: super(EncodingWidget, self).__init__((), **kwargs) @@ -36,36 +38,43 @@ def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=Fa self.widget = self.ipywidget_factory.get_vbox() - self.title = self.ipywidget_factory.get_html('Encoding:', width='148px', height='32px') + self.title = self.ipywidget_factory.get_html( + "Encoding:", width="148px", height="32px" + ) # X view options_x_view = {text(i): text(i) for i in self.df.columns} options_x_view["-"] = None - self.x_view = self.ipywidget_factory.get_dropdown(options=options_x_view, - description="X", value=self.encoding.x) - self.x_view.on_trait_change(self._x_changed_callback, 'value') + self.x_view = self.ipywidget_factory.get_dropdown( + options=options_x_view, description="X", 
value=self.encoding.x + ) + self.x_view.on_trait_change(self._x_changed_callback, "value") self.x_view.layout.width = "200px" # Y options_y_view = {text(i): text(i) for i in self.df.columns} options_y_view["-"] = None - y_column_view = self.ipywidget_factory.get_dropdown(options=options_y_view, - description="Y", value=self.encoding.y) - y_column_view.on_trait_change(self._y_changed_callback, 'value') + y_column_view = self.ipywidget_factory.get_dropdown( + options=options_y_view, description="Y", value=self.encoding.y + ) + y_column_view.on_trait_change(self._y_changed_callback, "value") y_column_view.layout.width = "200px" # Y aggregator value_for_view = self._get_value_for_aggregation(self.encoding.y_aggregation) self.y_agg_view = self.ipywidget_factory.get_dropdown( - options={"-": Encoding.y_agg_none, - Encoding.y_agg_avg: Encoding.y_agg_avg, - Encoding.y_agg_min: Encoding.y_agg_min, - Encoding.y_agg_max: Encoding.y_agg_max, - Encoding.y_agg_sum: Encoding.y_agg_sum, - Encoding.y_agg_count: Encoding.y_agg_count}, + options={ + "-": Encoding.y_agg_none, + Encoding.y_agg_avg: Encoding.y_agg_avg, + Encoding.y_agg_min: Encoding.y_agg_min, + Encoding.y_agg_max: Encoding.y_agg_max, + Encoding.y_agg_sum: Encoding.y_agg_sum, + Encoding.y_agg_count: Encoding.y_agg_count, + }, description="Func.", - value=value_for_view) - self.y_agg_view.on_trait_change(self._y_agg_changed_callback, 'value') + value=value_for_view, + ) + self.y_agg_view.on_trait_change(self._y_agg_changed_callback, "value") self.y_agg_view.layout.width = "200px" # Y view @@ -74,15 +83,23 @@ def __init__(self, df, encoding, change_hook, ipywidget_factory=None, testing=Fa # Logarithmic X axis self.logarithmic_x_axis = self.ipywidget_factory.get_checkbox( - description="Log scale X", value=encoding.logarithmic_x_axis) + description="Log scale X", value=encoding.logarithmic_x_axis + ) self.logarithmic_x_axis.on_trait_change(self._logarithmic_x_callback, "value") # Logarithmic Y axis self.logarithmic_y_axis = self.ipywidget_factory.get_checkbox( - description="Log scale Y", value=encoding.logarithmic_y_axis) + description="Log scale Y", value=encoding.logarithmic_y_axis + ) self.logarithmic_y_axis.on_trait_change(self._logarithmic_y_callback, "value") - children = [self.title, self.x_view, self.y_view, self.logarithmic_x_axis, self.logarithmic_y_axis] + children = [ + self.title, + self.x_view, + self.y_view, + self.logarithmic_x_axis, + self.logarithmic_y_axis, + ] self.widget.children = children self.children = [self.widget] @@ -109,28 +126,28 @@ def _get_value_for_aggregation(self, y_aggregation): return "none" def _x_changed_callback(self, name, old_value, new_value): - self.encoding.x = new_value - return self.change_hook() + self.encoding.x = new_value + return self.change_hook() def _y_changed_callback(self, name, old_value, new_value): - self.encoding.y = new_value - return self.change_hook() + self.encoding.y = new_value + return self.change_hook() def _y_agg_changed_callback(self, name, old_value, new_value): - if new_value == "none": - self.encoding.y_aggregation = None - else: - self.encoding.y_aggregation = new_value - return self.change_hook() + if new_value == "none": + self.encoding.y_aggregation = None + else: + self.encoding.y_aggregation = new_value + return self.change_hook() def _logarithmic_x_callback(self, name, old_value, new_value): - self.encoding.logarithmic_x_axis = new_value - return self.change_hook() + self.encoding.logarithmic_x_axis = new_value + return self.change_hook() def 
_logarithmic_y_callback(self, name, old_value, new_value): - self.encoding.logarithmic_y_axis = new_value - return self.change_hook() - + self.encoding.logarithmic_y_axis = new_value + return self.change_hook() + def _widget_visible(self, widget, visible): if visible: widget.layout.display = "flex" diff --git a/autovizwidget/autovizwidget/widget/utils.py b/autovizwidget/autovizwidget/widget/utils.py index b79ebd91b..ec3fa55e3 100644 --- a/autovizwidget/autovizwidget/widget/utils.py +++ b/autovizwidget/autovizwidget/widget/utils.py @@ -15,16 +15,28 @@ def infer_vegalite_type(data): typ = pd.api.types.infer_dtype(data) - if typ in ['floating', 'mixed-integer-float', 'integer', - 'mixed-integer', 'complex']: - typecode = 'Q' - elif typ in ['string', 'bytes', 'categorical', 'boolean', 'mixed', 'unicode']: - typecode = 'N' - elif typ in ['datetime', 'datetime64', 'timedelta', - 'timedelta64', 'date', 'time', 'period']: - typecode = 'T' + if typ in [ + "floating", + "mixed-integer-float", + "integer", + "mixed-integer", + "complex", + ]: + typecode = "Q" + elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]: + typecode = "N" + elif typ in [ + "datetime", + "datetime64", + "timedelta", + "timedelta64", + "date", + "time", + "period", + ]: + typecode = "T" else: - typecode = 'N' + typecode = "N" return typecode @@ -33,7 +45,7 @@ def _validate_custom_order(order): assert len(order) == 4 list_to_check = list(order) list_to_check.sort() - assert list_to_check == ['N', 'O', 'Q', 'T'] + assert list_to_check == ["N", "O", "Q", "T"] def _classify_data_by_type(data, order, skip=None): @@ -65,7 +77,7 @@ def select_x(data, order=None): return None if order is None: - order = ['T', 'O', 'N', 'Q'] + order = ["T", "O", "N", "Q"] else: _validate_custom_order(order) @@ -96,7 +108,7 @@ def select_y(data, x_name, order=None, aggregator=None): return None if order is None: - order = ['Q', 'O', 'N', 'T'] + order = ["Q", "O", "N", "T"] else: _validate_custom_order(order) @@ -115,6 +127,10 @@ def select_y(data, x_name, order=None, aggregator=None): def display_dataframe(df): selected_x = select_x(df) selected_y = select_y(df, selected_x) - encoding = Encoding(chart_type=Encoding.chart_type_table, x=selected_x, y=selected_y, - y_aggregation=Encoding.y_agg_max) + encoding = Encoding( + chart_type=Encoding.chart_type_table, + x=selected_x, + y=selected_y, + y_aggregation=Encoding.y_agg_max, + ) return AutoVizWidget(df, encoding) diff --git a/hdijupyterutils/hdijupyterutils/configuration.py b/hdijupyterutils/hdijupyterutils/configuration.py index a18a6a698..0cb47bec6 100644 --- a/hdijupyterutils/hdijupyterutils/configuration.py +++ b/hdijupyterutils/hdijupyterutils/configuration.py @@ -11,6 +11,7 @@ def with_override(overrides, path, fsrw_class=None): """A decorator which first initializes the overrided configurations, then checks the global overrided defaults for the given configuration, calling the function to get the default result otherwise.""" + def ret(f): def wrapped_f(*args): # Can access overrides and path here @@ -20,22 +21,22 @@ def wrapped_f(*args): return overrides[name] else: return f(*args) - + # Hack! We do this so that we can query the .__name__ of the function # later to get the name of the configuration dynamically, e.g. for unit tests wrapped_f.__name__ = f.__name__ return wrapped_f - + return ret - + def override(overrides, path, config, value, fsrw_class=None): """Given a string representing a configuration and a value for that configuration, override the configuration. 
Initialize the overrided configuration beforehand.""" _initialize(overrides, path, fsrw_class) overrides[config] = value - + def override_all(overrides, new_overrides): """Given a dictionary representing the overrided defaults for this configuration, initialize the global configuration.""" @@ -50,20 +51,20 @@ def _initialize(overrides, path, fsrw_class): if not overrides: new_overrides = _load(path, fsrw_class) override_all(overrides, new_overrides) - - + + def _load(path, fsrw_class=None): """Returns a dictionary of configuration by reading from the configuration file.""" if fsrw_class is None: fsrw_class = FileSystemReaderWriter - + config_file = fsrw_class(path) config_file.ensure_file_exists() config_text = config_file.read_lines() - line = u"".join(config_text).strip() - - if line == u"": + line = "".join(config_text).strip() + + if line == "": overrides = {} else: overrides = json.loads(line) diff --git a/hdijupyterutils/hdijupyterutils/constants.py b/hdijupyterutils/hdijupyterutils/constants.py index fdd4d8a7a..08a18b9b8 100644 --- a/hdijupyterutils/hdijupyterutils/constants.py +++ b/hdijupyterutils/hdijupyterutils/constants.py @@ -1,6 +1,6 @@ -LOGGING_CONFIG_CLASS_NAME = u"hdijupyterutils.filehandler.MagicsFileHandler" +LOGGING_CONFIG_CLASS_NAME = "hdijupyterutils.filehandler.MagicsFileHandler" -EVENTS_HANDLER_CLASS_NAME = u"hdijupyterutils.eventshandler.EventsHandler" +EVENTS_HANDLER_CLASS_NAME = "hdijupyterutils.eventshandler.EventsHandler" INSTANCE_ID = "InstanceId" TIMESTAMP = "Timestamp" EVENT_NAME = "EventName" diff --git a/hdijupyterutils/hdijupyterutils/filehandler.py b/hdijupyterutils/hdijupyterutils/filehandler.py index e5a5fc0f9..6f74580b9 100644 --- a/hdijupyterutils/hdijupyterutils/filehandler.py +++ b/hdijupyterutils/hdijupyterutils/filehandler.py @@ -6,15 +6,19 @@ class MagicsFileHandler(logging.FileHandler): """The default logging handler used by the magics; this behavior can be overridden by modifying the config file""" + def __init__(self, **kwargs): # Simply invokes the behavior of the superclass, but sets the filename keyword argument if it's not already set. 
- if 'filename' in kwargs: + if "filename" in kwargs: super(MagicsFileHandler, self).__init__(**kwargs) else: - magics_home_path = kwargs.pop(u"home_path") + magics_home_path = kwargs.pop("home_path") logs_folder_name = "logs" log_file_name = "log_{}.log".format(get_instance_id()) - directory = FileSystemReaderWriter(join_paths(magics_home_path, logs_folder_name)) + directory = FileSystemReaderWriter( + join_paths(magics_home_path, logs_folder_name) + ) directory.ensure_path_exists() - super(MagicsFileHandler, self).__init__(filename=join_paths(directory.path, log_file_name), **kwargs) - + super(MagicsFileHandler, self).__init__( + filename=join_paths(directory.path, log_file_name), **kwargs + ) diff --git a/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py b/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py index e948cac2b..5ede09cf6 100644 --- a/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py +++ b/hdijupyterutils/hdijupyterutils/filesystemreaderwriter.py @@ -4,9 +4,9 @@ class FileSystemReaderWriter(object): - def __init__(self, path): from .utils import expand_path + assert path is not None self.path = expand_path(path) @@ -16,7 +16,7 @@ def ensure_path_exists(self): def ensure_file_exists(self): self._ensure_path_exists(os.path.dirname(self.path)) if not os.path.exists(self.path): - open(self.path, 'w').close() + open(self.path, "w").close() def read_lines(self): if os.path.isfile(self.path): diff --git a/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py b/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py index 23164872f..249857f1a 100644 --- a/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py +++ b/hdijupyterutils/hdijupyterutils/ipywidgetfactory.py @@ -1,7 +1,20 @@ # Copyright (c) 2015 aggftw@gmail.com # Distributed under the terms of the Modified BSD License. -from ipywidgets import VBox, Output, Button, HTML, HBox, Dropdown, Checkbox, ToggleButtons, Text, Textarea, Tab, Password +from ipywidgets import ( + VBox, + Output, + Button, + HTML, + HBox, + Dropdown, + Checkbox, + ToggleButtons, + Text, + Textarea, + Tab, + Password, +) class IpyWidgetFactory(object): diff --git a/hdijupyterutils/hdijupyterutils/log.py b/hdijupyterutils/hdijupyterutils/log.py index 3e9b3dfaf..943403325 100644 --- a/hdijupyterutils/hdijupyterutils/log.py +++ b/hdijupyterutils/hdijupyterutils/log.py @@ -9,9 +9,10 @@ class Log(object): """Logger for magics. 
A small wrapper class around the configured logger described in the configuration file""" + def __init__(self, logger_name, logging_config, caller_name): logging.config.dictConfig(logging_config) - + assert caller_name is not None self._caller_name = caller_name self.logger_name = logger_name @@ -30,30 +31,30 @@ def _getLogger(self): self.logger = logging.getLogger(self.logger_name) def _transform_log_message(self, message): - return u'{}\t{}'.format(self._caller_name, message) + return "{}\t{}".format(self._caller_name, message) def logging_config(): return { - u"version": 1, - u"formatters": { - u"magicsFormatter": { - u"format": u"%(asctime)s\t%(levelname)s\t%(message)s", - u"datefmt": u"" + "version": 1, + "formatters": { + "magicsFormatter": { + "format": "%(asctime)s\t%(levelname)s\t%(message)s", + "datefmt": "", } }, - u"handlers": { - u"magicsHandler": { - u"class": LOGGING_CONFIG_CLASS_NAME, - u"formatter": u"magicsFormatter", - u"home_path": "~/.hdijupyterutils" + "handlers": { + "magicsHandler": { + "class": LOGGING_CONFIG_CLASS_NAME, + "formatter": "magicsFormatter", + "home_path": "~/.hdijupyterutils", } }, - u"loggers": { - u"magicsLogger": { - u"handlers": [u"magicsHandler"], - u"level": u"DEBUG", - u"propagate": 0 + "loggers": { + "magicsLogger": { + "handlers": ["magicsHandler"], + "level": "DEBUG", + "propagate": 0, } - } - } + }, + } diff --git a/hdijupyterutils/hdijupyterutils/tests/test_configuration.py b/hdijupyterutils/hdijupyterutils/tests/test_configuration.py index fb43024db..d391cdd80 100644 --- a/hdijupyterutils/hdijupyterutils/tests/test_configuration.py +++ b/hdijupyterutils/hdijupyterutils/tests/test_configuration.py @@ -11,7 +11,7 @@ path = "~/.testing/config.json" original_value = 0 - + def module_override(config, value): global d, path override(d, path, config, value) @@ -27,35 +27,35 @@ def module_override_all(obj): def my_config(): global original_value return original_value - - + + @with_override(d, path) def my_config_2(): global original_value - return original_value - - + return original_value + + # Test helper functions def _setup(): module_override_all({}) - - + + def _teardown(): module_override_all({}) - + # Unit tests begin @with_setup(_setup, _teardown) def test_original_value_without_overrides(): assert_equals(original_value, my_config()) - + @with_setup(_setup, _teardown) def test_original_value_with_overrides(): new_value = 2 module_override(my_config.__name__, new_value) assert_equals(new_value, my_config()) - + @with_setup(_setup, _teardown) def test_original_values_when_others_override(): @@ -63,15 +63,15 @@ def test_original_values_when_others_override(): module_override(my_config.__name__, new_value) assert_equals(new_value, my_config()) assert_equals(original_value, my_config_2()) - - + + @with_setup(_setup, _teardown) def test_resetting_values_when_others_override(): new_value = 2 module_override(my_config.__name__, new_value) assert_equals(new_value, my_config()) assert_equals(original_value, my_config_2()) - + # Reset module_override_all({}) assert_equals(original_value, my_config()) diff --git a/hdijupyterutils/hdijupyterutils/tests/test_events.py b/hdijupyterutils/hdijupyterutils/tests/test_events.py index 614f01d09..1010806d2 100644 --- a/hdijupyterutils/hdijupyterutils/tests/test_events.py +++ b/hdijupyterutils/hdijupyterutils/tests/test_events.py @@ -30,24 +30,26 @@ def test_send_to_handler(): events.send_to_handler(kwargs_list) events.handler.handle_event.assert_called_once_with(expected_kwargs_list) - - + + 
 @with_setup(_setup, _teardown)
 @raises(AssertionError)
 def test_send_to_handler_asserts_less_than_12():
-    kwargs_list = [(TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp),
-                   (TIMESTAMP, time_stamp)]
+    kwargs_list = [
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+        (TIMESTAMP, time_stamp),
+    ]
 
     events.send_to_handler(kwargs_list)
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py b/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
index b5de9dd4c..831221971 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_ipythondisplay.py
@@ -3,20 +3,22 @@
 from mock import MagicMock
 import sys
 
+
 def test_stdout_flush():
     ipython_shell = MagicMock()
     ipython_display = IpythonDisplay()
     ipython_display._ipython_shell = ipython_shell
     sys.stdout = MagicMock()
 
-    ipython_display.write(u'Testing Stdout Flush è')
+    ipython_display.write("Testing Stdout Flush è")
     assert sys.stdout.flush.call_count == 1
 
+
 def test_stderr_flush():
     ipython_shell = MagicMock()
     ipython_display = IpythonDisplay()
     ipython_display._ipython_shell = ipython_shell
     sys.stderr = MagicMock()
 
-    ipython_display.send_error(u'Testing Stderr Flush è')
+    ipython_display.send_error("Testing Stderr Flush è")
     assert sys.stderr.flush.call_count == 1
diff --git a/hdijupyterutils/hdijupyterutils/tests/test_logger.py b/hdijupyterutils/hdijupyterutils/tests/test_logger.py
index 646eb3ae7..92a1785fd 100644
--- a/hdijupyterutils/hdijupyterutils/tests/test_logger.py
+++ b/hdijupyterutils/hdijupyterutils/tests/test_logger.py
@@ -11,7 +11,7 @@ def get_logging_config():
 
 def test_log_init():
     logging_config = get_logging_config()
-    logger = Log('name', logging_config, 'something')
+    logger = Log("name", logging_config, "something")
 
     assert isinstance(logger.logger, logging.Logger)
 
@@ -22,49 +22,49 @@ def __init__(self):
         self.level = self.message = None
 
     def debug(self, message):
-        self.level, self.message = 'DEBUG', message
-
+        self.level, self.message = "DEBUG", message
+
     def error(self, message):
-        self.level, self.message = 'ERROR', message
+        self.level, self.message = "ERROR", message
 
     def info(self, message):
-        self.level, self.message = 'INFO', message
+        self.level, self.message = "INFO", message
 
 
 class MockLog(Log):
     def __init__(self, name):
         logging_config = get_logging_config()
         super(MockLog, self).__init__(name, logging_config, name)
-
+
     def _getLogger(self):
-        self.logger = MockLogger()
+        self.logger = MockLogger()
 
 
 def test_log_returnvalue():
-    logger = MockLog('test2')
+    logger = MockLog("test2")
     assert isinstance(logger.logger, MockLogger)
     mock = logger.logger
 
-    logger.debug('word1')
-    assert mock.level == 'DEBUG'
-    assert_equals(mock.message, 'test2\tword1')
-    logger.error('word2')
-    assert mock.level == 'ERROR'
-    assert mock.message == 'test2\tword2'
-    logger.info('word3')
-    assert mock.level == 'INFO'
-    assert mock.message == 'test2\tword3'
+    logger.debug("word1")
+    assert mock.level == "DEBUG"
+    assert_equals(mock.message, "test2\tword1")
+    logger.error("word2")
+    assert mock.level == "ERROR"
+    assert mock.message == "test2\tword2"
+    logger.info("word3")
+    assert mock.level == "INFO"
+    assert mock.message == "test2\tword3"
 
 
 def test_log_unicode():
-    logger = MockLog('test2')
+    logger = MockLog("test2")
     assert isinstance(logger.logger, MockLogger)
     mock = logger.logger
 
-    logger.debug(u'word1è')
-    assert mock.level == 'DEBUG'
-    assert mock.message == u'test2\tword1è'
-    logger.error(u'word2è')
-    assert mock.level == 'ERROR'
-    assert mock.message == u'test2\tword2è'
-    logger.info(u'word3è')
-    assert mock.level == 'INFO'
-    assert mock.message == u'test2\tword3è'
+    logger.debug("word1è")
+    assert mock.level == "DEBUG"
+    assert mock.message == "test2\tword1è"
+    logger.error("word2è")
+    assert mock.level == "ERROR"
+    assert mock.message == "test2\tword2è"
+    logger.info("word3è")
+    assert mock.level == "INFO"
+    assert mock.message == "test2\tword3è"
diff --git a/sparkmagic/sparkmagic/auth/basic.py b/sparkmagic/sparkmagic/auth/basic.py
index edecf67d0..381f1c21e 100644
--- a/sparkmagic/sparkmagic/auth/basic.py
+++ b/sparkmagic/sparkmagic/auth/basic.py
@@ -5,8 +5,10 @@
 from requests.auth import HTTPBasicAuth
 from .customauth import Authenticator
 
+
 class Basic(HTTPBasicAuth, Authenticator):
     """Basic Access authenticator for SparkMagic"""
+
     def __init__(self, parsed_attributes=None):
         """Initializes the Authenticator with the attributes in the attributes parsed
         from a %spark magic command if applicable, or with default values
@@ -18,15 +20,17 @@ def __init__(self, parsed_attributes=None):
         is created from parsing %spark magic command.
         """
         if parsed_attributes is not None:
-            if parsed_attributes.user == '' or parsed_attributes.password == '':
-                new_exc = BadUserDataException("Need to supply username and password arguments for "\
-                                               "Basic Access Authentication. (e.g. -a username -p password).")
+            if parsed_attributes.user == "" or parsed_attributes.password == "":
+                new_exc = BadUserDataException(
+                    "Need to supply username and password arguments for "
+                    "Basic Access Authentication. (e.g. -a username -p password)."
+ ) raise new_exc self.username = parsed_attributes.user self.password = parsed_attributes.password else: - self.username = 'username' - self.password = 'password' + self.username = "username" + self.password = "password" HTTPBasicAuth.__init__(self, self.username, self.password) Authenticator.__init__(self, parsed_attributes) @@ -42,15 +46,11 @@ def get_widgets(self, widget_width): ipywidget_factory = IpyWidgetFactory() self.user_widget = ipywidget_factory.get_text( - description='Username:', - value=self.username, - width=widget_width + description="Username:", value=self.username, width=widget_width ) self.password_widget = ipywidget_factory.get_password( - description='Password:', - value=self.password, - width=widget_width + description="Password:", value=self.password, width=widget_width ) widgets = [self.user_widget, self.password_widget] @@ -65,8 +65,11 @@ def update_with_widget_values(self): def __eq__(self, other): if not isinstance(other, Basic): return False - return self.url == other.url and self.username == other.username and \ - self.password == other.password + return ( + self.url == other.url + and self.username == other.username + and self.password == other.password + ) def __call__(self, request): return HTTPBasicAuth.__call__(self, request) diff --git a/sparkmagic/sparkmagic/auth/customauth.py b/sparkmagic/sparkmagic/auth/customauth.py index 3c8dfb6c7..b84261403 100644 --- a/sparkmagic/sparkmagic/auth/customauth.py +++ b/sparkmagic/sparkmagic/auth/customauth.py @@ -3,6 +3,7 @@ from hdijupyterutils.ipywidgetfactory import IpyWidgetFactory from sparkmagic.utils.constants import WIDGET_WIDTH + class Authenticator(object): """Base Authenticator for all Sparkmagic authentication providers.""" @@ -19,7 +20,7 @@ def __init__(self, parsed_attributes=None): if parsed_attributes is not None: self.url = parsed_attributes.url else: - self.url = 'http://example.com/livy' + self.url = "http://example.com/livy" self.widgets = self.get_widgets(WIDGET_WIDTH) def get_widgets(self, widget_width): @@ -34,9 +35,7 @@ def get_widgets(self, widget_width): ipywidget_factory = IpyWidgetFactory() self.address_widget = ipywidget_factory.get_text( - description='Address:', - value='http://example.com/livy', - width=widget_width + description="Address:", value="http://example.com/livy", width=widget_width ) widgets = [self.address_widget] return widgets diff --git a/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py b/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py index 61e167272..b0557c417 100644 --- a/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/abstractmenuwidget.py @@ -5,9 +5,16 @@ class AbstractMenuWidget(Box): - def __init__(self, spark_controller, ipywidget_factory=None, ipython_display=None, - nested_widget_mode=False, testing=False, **kwargs): - kwargs['orientation'] = 'vertical' + def __init__( + self, + spark_controller, + ipywidget_factory=None, + ipython_display=None, + nested_widget_mode=False, + testing=False, + **kwargs + ): + kwargs["orientation"] = "vertical" if not testing: super(AbstractMenuWidget, self).__init__((), **kwargs) diff --git a/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py b/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py index ad460b36c..c795c80ab 100644 --- a/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/addendpointwidget.py @@ -8,48 +8,64 @@ class AddEndpointWidget(AbstractMenuWidget): - - def 
__init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints, endpoints_dropdown_widget, - refresh_method): + def __init__( + self, + spark_controller, + ipywidget_factory, + ipython_display, + endpoints, + endpoints_dropdown_widget, + refresh_method, + ): # This is nested - super(AddEndpointWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True) + super(AddEndpointWidget, self).__init__( + spark_controller, ipywidget_factory, ipython_display, True + ) self.endpoints = endpoints self.endpoints_dropdown_widget = endpoints_dropdown_widget self.refresh_method = refresh_method - #map auth class path string to the instance of the class. + # map auth class path string to the instance of the class. self.auth_instances = {} for auth in conf.authenticators().values(): - module, class_name = (auth).rsplit('.', 1) + module, class_name = (auth).rsplit(".", 1) events_handler_module = importlib.import_module(module) auth_class = getattr(events_handler_module, class_name) self.auth_instances[auth] = auth_class() self.auth_type = self.ipywidget_factory.get_dropdown( - options=conf.authenticators(), - description=u"Auth type:" + options=conf.authenticators(), description="Auth type:" ) - #combine all authentication instance's widgets into one list to pass to self.children. + # combine all authentication instance's widgets into one list to pass to self.children. self.all_widgets = list() for _class, instance in self.auth_instances.items(): for widget in instance.widgets: - if _class == self.auth_type.value: - widget.layout.display = 'flex' + if _class == self.auth_type.value: + widget.layout.display = "flex" self.auth = instance else: - widget.layout.display = 'none' + widget.layout.display = "none" self.all_widgets.append(widget) # Submit widget self.submit_widget = self.ipywidget_factory.get_submit_button( - description='Add endpoint' + description="Add endpoint" ) self.auth_type.on_trait_change(self._update_auth) - self.children = [self.ipywidget_factory.get_html(value="
", width=WIDGET_WIDTH), self.auth_type] + self.all_widgets \ - + [self.ipywidget_factory.get_html(value="
", width=WIDGET_WIDTH), self.submit_widget] + self.children = ( + [ + self.ipywidget_factory.get_html(value="
", width=WIDGET_WIDTH), + self.auth_type, + ] + + self.all_widgets + + [ + self.ipywidget_factory.get_html(value="
", width=WIDGET_WIDTH), + self.submit_widget, + ] + ) for child in self.children: child.parent_widget = self @@ -77,7 +93,7 @@ def _update_auth(self): Create an instance of the chosen auth type maps to in the config file. """ for widget in self.auth.widgets: - widget.layout.display = 'none' + widget.layout.display = "none" self.auth = self.auth_instances.get(self.auth_type.value) for widget in self.auth.widgets: - widget.layout.display = 'flex' + widget.layout.display = "flex" diff --git a/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py b/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py index af534a7e5..d1d7aaa7d 100644 --- a/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/createsessionwidget.py @@ -8,33 +8,46 @@ class CreateSessionWidget(AbstractMenuWidget): - def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints_dropdown_widget, refresh_method): + def __init__( + self, + spark_controller, + ipywidget_factory, + ipython_display, + endpoints_dropdown_widget, + refresh_method, + ): # This is nested - super(CreateSessionWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True) + super(CreateSessionWidget, self).__init__( + spark_controller, ipywidget_factory, ipython_display, True + ) self.refresh_method = refresh_method self.endpoints_dropdown_widget = endpoints_dropdown_widget self.session_widget = self.ipywidget_factory.get_text( - description='Name:', - value='session-name' + description="Name:", value="session-name" ) self.lang_widget = self.ipywidget_factory.get_toggle_buttons( - description='Language:', + description="Language:", options=[LANG_SCALA, LANG_PYTHON], ) self.properties = self.ipywidget_factory.get_text( - description='Properties:', - value=json.dumps(conf.session_configs()) + description="Properties:", value=json.dumps(conf.session_configs()) ) self.submit_widget = self.ipywidget_factory.get_submit_button( - description='Create Session' + description="Create Session" ) - self.children = [self.ipywidget_factory.get_html(value="
", width="600px"), self.endpoints_dropdown_widget, - self.session_widget, self.lang_widget, self.properties, - self.ipywidget_factory.get_html(value="
", width="600px"), self.submit_widget] + self.children = [ + self.ipywidget_factory.get_html(value="
", width="600px"), + self.endpoints_dropdown_widget, + self.session_widget, + self.lang_widget, + self.properties, + self.ipywidget_factory.get_html(value="
", width="600px"), + self.submit_widget, + ] for child in self.children: child.parent_widget = self @@ -43,9 +56,13 @@ def run(self): try: properties_json = self.properties.value if properties_json.strip() != "": - conf.override(conf.session_configs.__name__, json.loads(self.properties.value)) + conf.override( + conf.session_configs.__name__, json.loads(self.properties.value) + ) except ValueError as e: - self.ipython_display.send_error("Session properties must be a valid JSON string. Error:\n{}".format(e)) + self.ipython_display.send_error( + "Session properties must be a valid JSON string. Error:\n{}".format(e) + ) return endpoint = self.endpoints_dropdown_widget.value @@ -57,13 +74,17 @@ def run(self): try: self.spark_controller.add_session(alias, endpoint, skip, properties) except ValueError as e: - self.ipython_display.send_error("""Could not add session with + self.ipython_display.send_error( + """Could not add session with name: {} properties: {} -due to error: '{}'""".format(alias, properties, e)) +due to error: '{}'""".format( + alias, properties, e + ) + ) return self.refresh_method() diff --git a/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py b/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py index 7e3ffd298..9e990ce57 100644 --- a/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/magicscontrollerwidget.py @@ -12,11 +12,17 @@ class MagicsControllerWidget(AbstractMenuWidget): - def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints=None): - super(MagicsControllerWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display) + def __init__( + self, spark_controller, ipywidget_factory, ipython_display, endpoints=None + ): + super(MagicsControllerWidget, self).__init__( + spark_controller, ipywidget_factory, ipython_display + ) if endpoints is None: - endpoints = {endpoint.url: endpoint for endpoint in self._get_default_endpoints()} + endpoints = { + endpoint.url: endpoint for endpoint in self._get_default_endpoints() + } self.endpoints = endpoints self._refresh() @@ -29,37 +35,73 @@ def _get_default_endpoints(): default_endpoints = set() for kernel_type in LANGS_SUPPORTED: - endpoint_config = getattr(conf, 'kernel_%s_credentials' % kernel_type)() - if all([p in endpoint_config for p in ["url", "password", "username"]]) and endpoint_config["url"] != "": + endpoint_config = getattr(conf, "kernel_%s_credentials" % kernel_type)() + if ( + all([p in endpoint_config for p in ["url", "password", "username"]]) + and endpoint_config["url"] != "" + ): user = endpoint_config["username"] passwd = endpoint_config["password"] - args = Namespace(user=user, password=passwd, auth=endpoint_config.get("auth", None), url=endpoint_config.get("url", None)) + args = Namespace( + user=user, + password=passwd, + auth=endpoint_config.get("auth", None), + url=endpoint_config.get("url", None), + ) auth_instance = initialize_auth(args) - default_endpoints.add(Endpoint( - auth=auth_instance, - url=endpoint_config["url"], - implicitly_added=True)) + default_endpoints.add( + Endpoint( + auth=auth_instance, + url=endpoint_config["url"], + implicitly_added=True, + ) + ) return default_endpoints def _refresh(self): self.endpoints_dropdown_widget = self.ipywidget_factory.get_dropdown( - description="Endpoint:", - options=self.endpoints + description="Endpoint:", options=self.endpoints ) - self.manage_session = ManageSessionWidget(self.spark_controller, self.ipywidget_factory, 
self.ipython_display, - self._refresh) - self.create_session = CreateSessionWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display, - self.endpoints_dropdown_widget, self._refresh) - self.add_endpoint = AddEndpointWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display, - self.endpoints, self.endpoints_dropdown_widget, self._refresh) - self.manage_endpoint = ManageEndpointWidget(self.spark_controller, self.ipywidget_factory, self.ipython_display, - self.endpoints, self._refresh) + self.manage_session = ManageSessionWidget( + self.spark_controller, + self.ipywidget_factory, + self.ipython_display, + self._refresh, + ) + self.create_session = CreateSessionWidget( + self.spark_controller, + self.ipywidget_factory, + self.ipython_display, + self.endpoints_dropdown_widget, + self._refresh, + ) + self.add_endpoint = AddEndpointWidget( + self.spark_controller, + self.ipywidget_factory, + self.ipython_display, + self.endpoints, + self.endpoints_dropdown_widget, + self._refresh, + ) + self.manage_endpoint = ManageEndpointWidget( + self.spark_controller, + self.ipywidget_factory, + self.ipython_display, + self.endpoints, + self._refresh, + ) - self.tabs = self.ipywidget_factory.get_tab(children=[self.manage_session, self.create_session, - self.add_endpoint, self.manage_endpoint]) + self.tabs = self.ipywidget_factory.get_tab( + children=[ + self.manage_session, + self.create_session, + self.add_endpoint, + self.manage_endpoint, + ] + ) self.tabs.set_title(0, "Manage Sessions") self.tabs.set_title(1, "Create Session") self.tabs.set_title(2, "Add Endpoint") diff --git a/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py b/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py index c210d6264..1020bfdf0 100644 --- a/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/manageendpointwidget.py @@ -6,9 +6,18 @@ class ManageEndpointWidget(AbstractMenuWidget): - def __init__(self, spark_controller, ipywidget_factory, ipython_display, endpoints, refresh_method): + def __init__( + self, + spark_controller, + ipywidget_factory, + ipython_display, + endpoints, + refresh_method, + ): # This is nested - super(ManageEndpointWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True) + super(ManageEndpointWidget, self).__init__( + spark_controller, ipywidget_factory, ipython_display, True + ) self.logger = SparkLog("ManageEndpointWidget") self.endpoints = endpoints @@ -24,7 +33,9 @@ def run(self): def get_existing_endpoint_widgets(self): endpoint_widgets = [] - endpoint_widgets.append(self.ipywidget_factory.get_html(value="
", width="600px")) + endpoint_widgets.append( + self.ipywidget_factory.get_html(value="
", width="600px") + ) if len(self.endpoints) > 0: # Header @@ -40,11 +51,20 @@ def get_existing_endpoint_widgets(self): if not endpoint.implicitly_added: raise else: - self.logger.info("Failed to connect to implicitly-defined endpoint at: %s" % url) - - endpoint_widgets.append(self.ipywidget_factory.get_html(value="
", width="600px")) + self.logger.info( + "Failed to connect to implicitly-defined endpoint at: %s" + % url + ) + + endpoint_widgets.append( + self.ipywidget_factory.get_html(value="
", width="600px") + ) else: - endpoint_widgets.append(self.ipywidget_factory.get_html(value="No endpoints yet.", width="600px")) + endpoint_widgets.append( + self.ipywidget_factory.get_html( + value="No endpoints yet.", width="600px" + ) + ) return endpoint_widgets @@ -63,7 +83,9 @@ def get_endpoint_widget(self, url, endpoint): hbox_outter_children.append(vbox_left) hbox_outter_children.append(cleanup_w) except ValueError as e: - hbox_outter_children.append(self.ipywidget_factory.get_html(value=str(e), width=width)) + hbox_outter_children.append( + self.ipywidget_factory.get_html(value=str(e), width=width) + ) hbox_outter_children.append(self.get_delete_button_endpoint(url, endpoint)) hbox_outter.children = hbox_outter_children @@ -76,7 +98,9 @@ def get_endpoint_left(self, endpoint, url): # 400 px info = self.get_info_endpoint_widget(endpoint, url) delete_session_number = self.get_delete_session_endpoint_widget(url, endpoint) - vbox_left = self.ipywidget_factory.get_vbox(children=[info, delete_session_number], width="400px") + vbox_left = self.ipywidget_factory.get_vbox( + children=[info, delete_session_number], width="400px" + ) return vbox_left def get_cleanup_button_endpoint(self, url, endpoint): @@ -84,7 +108,9 @@ def cleanup_on_click(button): try: self.spark_controller.cleanup_endpoint(endpoint) except ValueError as e: - self.ipython_display.send_error("Could not clean up endpoint due to error: {}".format(e)) + self.ipython_display.send_error( + "Could not clean up endpoint due to error: {}".format(e) + ) return self.ipython_display.writeln("Cleaned up endpoint {}".format(url)) self.refresh_method() @@ -105,7 +131,9 @@ def delete_on_click(button): return delete_w def get_delete_session_endpoint_widget(self, url, endpoint): - session_text = self.ipywidget_factory.get_text(description="Session to delete:", value="0", width="50px") + session_text = self.ipywidget_factory.get_text( + description="Session to delete:", value="0", width="50px" + ) def delete_endpoint(button): try: @@ -120,7 +148,9 @@ def delete_endpoint(button): button = self.ipywidget_factory.get_button(description="Delete") button.on_click(delete_endpoint) - return self.ipywidget_factory.get_hbox(children=[session_text, button], width="152px") + return self.ipywidget_factory.get_hbox( + children=[session_text, button], width="152px" + ) def get_info_endpoint_widget(self, endpoint, url): # 400 px @@ -129,7 +159,9 @@ def get_info_endpoint_widget(self, endpoint, url): info_sessions = self.spark_controller.get_all_sessions_endpoint_info(endpoint) if len(info_sessions) > 0: - text = "{}:
{}".format(url, "* {}".format("
* ".join(info_sessions))) + text = "{}:
{}".format( + url, "* {}".format("
* ".join(info_sessions)) + ) else: text = "No sessions on this endpoint." diff --git a/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py b/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py index 268f049ab..ebd38951d 100644 --- a/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py +++ b/sparkmagic/sparkmagic/controllerwidget/managesessionwidget.py @@ -4,9 +4,13 @@ class ManageSessionWidget(AbstractMenuWidget): - def __init__(self, spark_controller, ipywidget_factory, ipython_display, refresh_method): + def __init__( + self, spark_controller, ipywidget_factory, ipython_display, refresh_method + ): # This is nested - super(ManageSessionWidget, self).__init__(spark_controller, ipywidget_factory, ipython_display, True) + super(ManageSessionWidget, self).__init__( + spark_controller, ipywidget_factory, ipython_display, True + ) self.refresh_method = refresh_method @@ -20,34 +24,55 @@ def run(self): def get_existing_session_widgets(self): session_widgets = [] - session_widgets.append(self.ipywidget_factory.get_html(value="
", width="600px")) + session_widgets.append( + self.ipywidget_factory.get_html(value="
", width="600px") + ) client_dict = self.spark_controller.get_managed_clients() if len(client_dict) > 0: # Header header = self.get_session_widget("Name", "Id", "Kind", "State", False) session_widgets.append(header) - session_widgets.append(self.ipywidget_factory.get_html(value="
", width="600px")) + session_widgets.append( + self.ipywidget_factory.get_html(value="
", width="600px") + ) # Sessions for name, session in client_dict.items(): - session_widgets.append(self.get_session_widget(name, session.id, session.kind, session.status)) - - session_widgets.append(self.ipywidget_factory.get_html(value="
", width="600px")) + session_widgets.append( + self.get_session_widget( + name, session.id, session.kind, session.status + ) + ) + + session_widgets.append( + self.ipywidget_factory.get_html(value="
", width="600px") + ) else: - session_widgets.append(self.ipywidget_factory.get_html(value="No sessions yet.", width="600px")) + session_widgets.append( + self.ipywidget_factory.get_html(value="No sessions yet.", width="600px") + ) return session_widgets def get_session_widget(self, name, session_id, kind, state, button=True): hbox = self.ipywidget_factory.get_hbox() - name_w = self.ipywidget_factory.get_html(value=name, width="200px", padding="4px") - id_w = self.ipywidget_factory.get_html(value=str(session_id), width="100px", padding="4px") - kind_w = self.ipywidget_factory.get_html(value=kind, width="100px", padding="4px") - state_w = self.ipywidget_factory.get_html(value=state, width="100px", padding="4px") + name_w = self.ipywidget_factory.get_html( + value=name, width="200px", padding="4px" + ) + id_w = self.ipywidget_factory.get_html( + value=str(session_id), width="100px", padding="4px" + ) + kind_w = self.ipywidget_factory.get_html( + value=kind, width="100px", padding="4px" + ) + state_w = self.ipywidget_factory.get_html( + value=state, width="100px", padding="4px" + ) if button: + def delete_on_click(button): self.spark_controller.delete_session_by_name(name) self.refresh_method() @@ -55,7 +80,9 @@ def delete_on_click(button): delete_w = self.ipywidget_factory.get_button(description="Delete") delete_w.on_click(delete_on_click) else: - delete_w = self.ipywidget_factory.get_html(value="", width="100px", padding="4px") + delete_w = self.ipywidget_factory.get_html( + value="", width="100px", padding="4px" + ) hbox.children = [name_w, id_w, kind_w, state_w, delete_w] diff --git a/sparkmagic/sparkmagic/kernels/kernelmagics.py b/sparkmagic/sparkmagic/kernels/kernelmagics.py index b22771994..1269947ad 100644 --- a/sparkmagic/sparkmagic/kernels/kernelmagics.py +++ b/sparkmagic/sparkmagic/kernels/kernelmagics.py @@ -15,31 +15,52 @@ import sparkmagic.utils.configuration as conf from sparkmagic.utils.configuration import get_livy_kind from sparkmagic.utils import constants -from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value, initialize_auth, Namespace +from sparkmagic.utils.utils import ( + parse_argstring_or_throw, + get_coerce_value, + initialize_auth, + Namespace, +) from sparkmagic.utils.sparkevents import SparkEvents from sparkmagic.utils.constants import LANGS_SUPPORTED -from sparkmagic.utils.dataframe_parser import cell_contains_dataframe, CellOutputHtmlParser +from sparkmagic.utils.dataframe_parser import ( + cell_contains_dataframe, + CellOutputHtmlParser, +) from sparkmagic.livyclientlib.command import Command from sparkmagic.livyclientlib.endpoint import Endpoint from sparkmagic.magics.sparkmagicsbase import SparkMagicBase, SparkOutputHandler -from sparkmagic.livyclientlib.exceptions import handle_expected_exceptions, wrap_unexpected_exceptions, \ - BadUserDataException +from sparkmagic.livyclientlib.exceptions import ( + handle_expected_exceptions, + wrap_unexpected_exceptions, + BadUserDataException, +) def _event(f): def wrapped(self, *args, **kwargs): guid = self._generate_uuid() - self._spark_events.emit_magic_execution_start_event(f.__name__, get_livy_kind(self.language), guid) + self._spark_events.emit_magic_execution_start_event( + f.__name__, get_livy_kind(self.language), guid + ) try: result = f(self, *args, **kwargs) except Exception as e: - self._spark_events.emit_magic_execution_end_event(f.__name__, get_livy_kind(self.language), guid, - False, e.__class__.__name__, str(e)) + self._spark_events.emit_magic_execution_end_event( + f.__name__, 
+ get_livy_kind(self.language), + guid, + False, + e.__class__.__name__, + str(e), + ) raise else: - self._spark_events.emit_magic_execution_end_event(f.__name__, get_livy_kind(self.language), guid, - True, u"", u"") + self._spark_events.emit_magic_execution_end_event( + f.__name__, get_livy_kind(self.language), guid, True, "", "" + ) return result + wrapped.__name__ = f.__name__ wrapped.__doc__ = f.__doc__ return wrapped @@ -51,15 +72,15 @@ def __init__(self, shell, data=None, spark_events=None): # You must call the parent constructor super(KernelMagics, self).__init__(shell, data) - self.session_name = u"session_name" + self.session_name = "session_name" self.session_started = False # In order to set these following 3 properties, call %%_do_not_call_change_language -l language - self.language = u"" + self.language = "" self.endpoint = None self.fatal_error = False self.allow_retry_fatal = False - self.fatal_error_message = u"" + self.fatal_error_message = "" if spark_events is None: spark_events = SparkEvents() self._spark_events = spark_events @@ -72,7 +93,7 @@ def __init__(self, shell, data=None, spark_events=None): def help(self, line, cell="", local_ns=None): parse_argstring_or_throw(self.help, line) self._assure_cell_body_is_empty(KernelMagics.help.__name__, cell) - help_html = u""" + help_html = """ @@ -165,23 +186,49 @@ def help(self, line, cell="", local_ns=None): self.ipython_display.html(help_html) @cell_magic - def local(self, line, cell=u"", local_ns=None): + def local(self, line, cell="", local_ns=None): # This should not be reachable thanks to UserCodeParser. Registering it here so that it auto-completes with tab. - raise NotImplementedError(u"UserCodeParser should have prevented code execution from reaching here.") + raise NotImplementedError( + "UserCodeParser should have prevented code execution from reaching here." + ) @magic_arguments() - @argument("-i", "--input", type=str, default=None, help="If present, indicated variable will be stored in variable" - " in Spark's context.") - @argument("-t", "--vartype", type=str, default='str', help="Optionally specify the type of input variable. " - "Available: 'str' - string(default) or 'df' - Pandas DataFrame") - @argument("-n", "--varname", type=str, default=None, help="Optionally specify the custom name for the input variable.") - @argument("-m", "--maxrows", type=int, default=2500, help="Maximum number of rows that will be pulled back " - "from the local dataframe") + @argument( + "-i", + "--input", + type=str, + default=None, + help="If present, indicated variable will be stored in variable" + " in Spark's context.", + ) + @argument( + "-t", + "--vartype", + type=str, + default="str", + help="Optionally specify the type of input variable. 
" + "Available: 'str' - string(default) or 'df' - Pandas DataFrame", + ) + @argument( + "-n", + "--varname", + type=str, + default=None, + help="Optionally specify the custom name for the input variable.", + ) + @argument( + "-m", + "--maxrows", + type=int, + default=2500, + help="Maximum number of rows that will be pulled back " + "from the local dataframe", + ) @cell_magic @needs_local_scope @wrap_unexpected_exceptions @handle_expected_exceptions - def send_to_spark(self, line, cell=u"", local_ns=None): + def send_to_spark(self, line, cell="", local_ns=None): self._assure_cell_body_is_empty(KernelMagics.send_to_spark.__name__, cell) args = parse_argstring_or_throw(self.send_to_spark, line) @@ -189,7 +236,9 @@ def send_to_spark(self, line, cell=u"", local_ns=None): raise BadUserDataException("-i param not provided.") if self._do_not_call_start_session(""): - self.do_send_to_spark(cell, args.input, args.vartype, args.varname, args.maxrows, None) + self.do_send_to_spark( + cell, args.input, args.vartype, args.varname, args.maxrows, None + ) else: return @@ -198,15 +247,21 @@ def send_to_spark(self, line, cell=u"", local_ns=None): @wrap_unexpected_exceptions @handle_expected_exceptions @_event - def info(self, line, cell=u"", local_ns=None): + def info(self, line, cell="", local_ns=None): parse_argstring_or_throw(self.info, line) self._assure_cell_body_is_empty(KernelMagics.info.__name__, cell) if self.session_started: - current_session_id = self.spark_controller.get_session_id_for_client(self.session_name) + current_session_id = self.spark_controller.get_session_id_for_client( + self.session_name + ) else: current_session_id = None - self.ipython_display.html(u"Current session configs: {}
".format(conf.get_session_properties(self.language))) + self.ipython_display.html( + "Current session configs: {}
".format( + conf.get_session_properties(self.language) + ) + ) info_sessions = self.spark_controller.get_all_sessions_endpoint(self.endpoint) self._print_endpoint_info(info_sessions, current_session_id) @@ -223,11 +278,19 @@ def logs(self, line, cell="", local_ns=None): out = self.spark_controller.get_logs() self.ipython_display.write(out) else: - self.ipython_display.write(u"No logs yet.") + self.ipython_display.write("No logs yet.") @magic_arguments() @cell_magic - @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.") + @argument( + "-f", + "--force", + type=bool, + default=False, + nargs="?", + const=True, + help="If present, user understands.", + ) @wrap_unexpected_exceptions @handle_expected_exceptions @_event @@ -235,44 +298,86 @@ def configure(self, line, cell="", local_ns=None): try: dictionary = json.loads(cell) except ValueError: - self.ipython_display.send_error(u"Could not parse JSON object from input '{}'".format(cell)) + self.ipython_display.send_error( + "Could not parse JSON object from input '{}'".format(cell) + ) return args = parse_argstring_or_throw(self.configure, line) if self.session_started: if not args.force: - self.ipython_display.send_error(u"A session has already been started. If you intend to recreate the " - u"session with new configurations, please include the -f argument.") + self.ipython_display.send_error( + "A session has already been started. If you intend to recreate the " + "session with new configurations, please include the -f argument." + ) return else: - self._do_not_call_delete_session(u"") + self._do_not_call_delete_session("") self._override_session_settings(dictionary) - self._do_not_call_start_session(u"") + self._do_not_call_start_session("") else: self._override_session_settings(dictionary) - self.info(u"") + self.info("") @magic_arguments() @cell_magic @needs_local_scope - @argument("-o", "--output", type=str, default=None, help="If present, indicated variable will be stored in variable" - "of this name in user's local context.") - @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for dataframe: either take or sample") - @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back " - "from the dataframe on the server for storing") - @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from dataframe") - @argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) " - "of the dataframe or not (pass False)") + @argument( + "-o", + "--output", + type=str, + default=None, + help="If present, indicated variable will be stored in variable" + "of this name in user's local context.", + ) + @argument( + "-m", + "--samplemethod", + type=str, + default=None, + help="Sample method for dataframe: either take or sample", + ) + @argument( + "-n", + "--maxrows", + type=int, + default=None, + help="Maximum number of rows that will be pulled back " + "from the dataframe on the server for storing", + ) + @argument( + "-r", + "--samplefraction", + type=float, + default=None, + help="Sample fraction for sampling from dataframe", + ) + @argument( + "-c", + "--coerce", + type=str, + default=None, + help="Whether to automatically coerce the types (default, pass True if being explicit) " + "of the dataframe or not (pass False)", + ) @wrap_unexpected_exceptions @handle_expected_exceptions def spark(self, line, 
cell="", local_ns=None): - if not self._do_not_call_start_session(u""): + if not self._do_not_call_start_session(""): return args = parse_argstring_or_throw(self.spark, line) coerce = get_coerce_value(args.coerce) - self.execute_spark(cell, args.output, args.samplemethod, args.maxrows, args.samplefraction, None, coerce) + self.execute_spark( + cell, + args.output, + args.samplemethod, + args.maxrows, + args.samplefraction, + None, + coerce, + ) @cell_magic @needs_local_scope @@ -280,34 +385,72 @@ def spark(self, line, cell="", local_ns=None): @handle_expected_exceptions def pretty(self, line, cell="", local_ns=None): """Evaluates a cell and converts dataframes in cell output to HTML tables.""" - if not self._do_not_call_start_session(u""): - return + if not self._do_not_call_start_session(""): + return def pretty_output_handler(out): if cell_contains_dataframe(out): - self.ipython_display.html(CellOutputHtmlParser.to_html(out)) - else: + self.ipython_display.html(CellOutputHtmlParser.to_html(out)) + else: self.ipython_display.write(out) - so = SparkOutputHandler(html=self.ipython_display.html, - text=pretty_output_handler, - default=self.ipython_display.display) - - self.execute_spark(cell, None, None, None, None, None, None, output_handler=so) + so = SparkOutputHandler( + html=self.ipython_display.html, + text=pretty_output_handler, + default=self.ipython_display.display, + ) + self.execute_spark(cell, None, None, None, None, None, None, output_handler=so) @magic_arguments() @cell_magic @needs_local_scope - @argument("-o", "--output", type=str, default=None, help="If present, query will be stored in variable of this " - "name.") - @argument("-q", "--quiet", type=bool, default=False, const=True, nargs="?", help="Return None instead of the dataframe.") - @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for SQL queries: either take or sample") - @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back " - "from the server for SQL queries") - @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from SQL queries") - @argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) " - "of the dataframe or not (pass False)") + @argument( + "-o", + "--output", + type=str, + default=None, + help="If present, query will be stored in variable of this " "name.", + ) + @argument( + "-q", + "--quiet", + type=bool, + default=False, + const=True, + nargs="?", + help="Return None instead of the dataframe.", + ) + @argument( + "-m", + "--samplemethod", + type=str, + default=None, + help="Sample method for SQL queries: either take or sample", + ) + @argument( + "-n", + "--maxrows", + type=int, + default=None, + help="Maximum number of rows that will be pulled back " + "from the server for SQL queries", + ) + @argument( + "-r", + "--samplefraction", + type=float, + default=None, + help="Sample fraction for sampling from SQL queries", + ) + @argument( + "-c", + "--coerce", + type=str, + default=None, + help="Whether to automatically coerce the types (default, pass True if being explicit) " + "of the dataframe or not (pass False)", + ) @wrap_unexpected_exceptions @handle_expected_exceptions def sql(self, line, cell="", local_ns=None): @@ -318,12 +461,28 @@ def sql(self, line, cell="", local_ns=None): coerce = get_coerce_value(args.coerce) - return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, 
args.samplefraction, - None, args.output, args.quiet, coerce) + return self.execute_sqlquery( + cell, + args.samplemethod, + args.maxrows, + args.samplefraction, + None, + args.output, + args.quiet, + coerce, + ) @magic_arguments() @cell_magic - @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.") + @argument( + "-f", + "--force", + type=bool, + default=False, + nargs="?", + const=True, + help="If present, user understands.", + ) @wrap_unexpected_exceptions @handle_expected_exceptions @_event @@ -331,18 +490,28 @@ def cleanup(self, line, cell="", local_ns=None): self._assure_cell_body_is_empty(KernelMagics.cleanup.__name__, cell) args = parse_argstring_or_throw(self.cleanup, line) if args.force: - self._do_not_call_delete_session(u"") + self._do_not_call_delete_session("") self.spark_controller.cleanup_endpoint(self.endpoint) else: - self.ipython_display.send_error(u"When you clean up the endpoint, all sessions will be lost, including the " - u"one used for this notebook. Include the -f parameter if that's your " - u"intention.") + self.ipython_display.send_error( + "When you clean up the endpoint, all sessions will be lost, including the " + "one used for this notebook. Include the -f parameter if that's your " + "intention." + ) return @magic_arguments() @cell_magic - @argument("-f", "--force", type=bool, default=False, nargs="?", const=True, help="If present, user understands.") + @argument( + "-f", + "--force", + type=bool, + default=False, + nargs="?", + const=True, + help="If present, user understands.", + ) @argument("-s", "--session", type=int, help="Session id number to delete.") @wrap_unexpected_exceptions @handle_expected_exceptions @@ -353,21 +522,27 @@ def delete(self, line, cell="", local_ns=None): session = args.session if args.session is None: - self.ipython_display.send_error(u'You must provide a session ID (-s argument).') + self.ipython_display.send_error( + "You must provide a session ID (-s argument)." + ) return if args.force: id = self.spark_controller.get_session_id_for_client(self.session_name) if session == id: - self.ipython_display.send_error(u"Cannot delete this kernel's session ({}). Specify a different session," - u" shutdown the kernel to delete this session, or run %cleanup to " - u"delete all sessions for this endpoint.".format(id)) + self.ipython_display.send_error( + "Cannot delete this kernel's session ({}). Specify a different session," + " shutdown the kernel to delete this session, or run %cleanup to " + "delete all sessions for this endpoint.".format(id) + ) return self.spark_controller.delete_session_by_id(self.endpoint, session) else: - self.ipython_display.send_error(u"Include the -f parameter if you understand that all statements executed " - u"in this session will be lost.") + self.ipython_display.send_error( + "Include the -f parameter if you understand that all statements executed " + "in this session will be lost." 
+ ) @cell_magic def _do_not_call_start_session(self, line, cell="", local_ns=None): @@ -385,14 +560,16 @@ def _do_not_call_start_session(self, line, cell="", local_ns=None): properties = conf.get_session_properties(self.language) try: - self.spark_controller.add_session(self.session_name, self.endpoint, skip, properties) + self.spark_controller.add_session( + self.session_name, self.endpoint, skip, properties + ) self.session_started = True self.fatal_error = False - self.fatal_error_message = u"" + self.fatal_error_message = "" except Exception as e: self.fatal_error = True self.fatal_error_message = conf.fatal_error_suggestion().format(e) - self.logger.error(u"Error creating session: {}".format(e)) + self.logger.error("Error creating session: {}".format(e)) self.ipython_display.send_error(self.fatal_error_message) if conf.all_errors_are_fatal(): @@ -427,11 +604,15 @@ def _do_not_call_change_language(self, line, cell="", local_ns=None): language = args.language.lower() if language not in LANGS_SUPPORTED: - self.ipython_display.send_error(u"'{}' language not supported in kernel magics.".format(language)) + self.ipython_display.send_error( + "'{}' language not supported in kernel magics.".format(language) + ) return if self.session_started: - self.ipython_display.send_error(u"Cannot change the language if a session has been started.") + self.ipython_display.send_error( + "Cannot change the language if a session has been started." + ) return self.language = language @@ -439,22 +620,24 @@ def _do_not_call_change_language(self, line, cell="", local_ns=None): @magic_arguments() @line_magic - @argument("-u", "--username", dest='user', type=str, help="Username to use.") + @argument("-u", "--username", dest="user", type=str, help="Username to use.") @argument("-p", "--password", type=str, help="Password to use.") - @argument("-s", "--server", dest='url', type=str, help="Url of server to use.") + @argument("-s", "--server", dest="url", type=str, help="Url of server to use.") @argument("-t", "--auth", type=str, help="Auth type for authentication") @_event def _do_not_call_change_endpoint(self, line, cell="", local_ns=None): args = parse_argstring_or_throw(self._do_not_call_change_endpoint, line) if self.session_started: - error = u"Cannot change the endpoint if a session has been started." + error = "Cannot change the endpoint if a session has been started." 
raise BadUserDataException(error) auth = initialize_auth(args=args) self.endpoint = Endpoint(args.url, auth) @line_magic def matplot(self, line, cell="", local_ns=None): - session = self.spark_controller.get_session_by_name_or_default(self.session_name) + session = self.spark_controller.get_session_by_name_or_default( + self.session_name + ) command = Command("%matplot " + line) (success, out, mimetype) = command.execute(session) if success: @@ -463,8 +646,13 @@ def matplot(self, line, cell="", local_ns=None): session.ipython_display.send_error(out) def refresh_configuration(self): - credentials = getattr(conf, 'base64_kernel_' + self.language + '_credentials')() - (username, password, auth, url) = (credentials['username'], credentials['password'], credentials['auth'], credentials['url']) + credentials = getattr(conf, "base64_kernel_" + self.language + "_credentials")() + (username, password, auth, url) = ( + credentials["username"], + credentials["password"], + credentials["auth"], + credentials["url"], + ) args = Namespace(auth=auth, user=username, password=password, url=url) auth_instance = initialize_auth(args) self.endpoint = Endpoint(url, auth_instance) @@ -492,8 +680,12 @@ def _generate_uuid(): @staticmethod def _assure_cell_body_is_empty(magic_name, cell): if cell.strip(): - raise BadUserDataException(u"Cell body for %%{} magic must be empty; got '{}' instead" - .format(magic_name, cell.strip())) + raise BadUserDataException( + "Cell body for %%{} magic must be empty; got '{}' instead".format( + magic_name, cell.strip() + ) + ) + def load_ipython_extension(ip): ip.register_magics(KernelMagics) diff --git a/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py b/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py index 99c4176c3..f102a9cad 100644 --- a/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py +++ b/sparkmagic/sparkmagic/kernels/pysparkkernel/__init__.py @@ -1 +1 @@ -__version__ = '0.0.1' \ No newline at end of file +__version__ = "0.0.1" diff --git a/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py b/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py index 7cbb6364e..a40c7e488 100644 --- a/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py +++ b/sparkmagic/sparkmagic/kernels/pysparkkernel/pysparkkernel.py @@ -6,29 +6,32 @@ class PySparkKernel(SparkKernelBase): def __init__(self, **kwargs): - implementation = 'PySpark' - implementation_version = '1.0' + implementation = "PySpark" + implementation_version = "1.0" language = LANG_PYTHON - language_version = '0.1' + language_version = "0.1" language_info = { - 'name': 'pyspark', - 'mimetype': 'text/x-python', - 'codemirror_mode': { - 'name': 'python', - 'version': 3 - }, - 'file_extension': '.py', - 'pygments_lexer': 'python3' + "name": "pyspark", + "mimetype": "text/x-python", + "codemirror_mode": {"name": "python", "version": 3}, + "file_extension": ".py", + "pygments_lexer": "python3", } session_language = LANG_PYTHON - super(PySparkKernel, - self).__init__(implementation, implementation_version, language, - language_version, language_info, session_language, - **kwargs) + super(PySparkKernel, self).__init__( + implementation, + implementation_version, + language, + language_version, + language_info, + session_language, + **kwargs + ) -if __name__ == '__main__': +if __name__ == "__main__": from ipykernel.kernelapp import IPKernelApp + IPKernelApp.launch_instance(kernel_class=PySparkKernel) diff --git a/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py 
b/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py index 99c4176c3..f102a9cad 100644 --- a/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py +++ b/sparkmagic/sparkmagic/kernels/sparkkernel/__init__.py @@ -1 +1 @@ -__version__ = '0.0.1' \ No newline at end of file +__version__ = "0.0.1" diff --git a/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py b/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py index 365eafe96..22da09b4d 100644 --- a/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py +++ b/sparkmagic/sparkmagic/kernels/sparkkernel/sparkkernel.py @@ -6,26 +6,32 @@ class SparkKernel(SparkKernelBase): def __init__(self, **kwargs): - implementation = 'Spark' - implementation_version = '1.0' + implementation = "Spark" + implementation_version = "1.0" language = LANG_SCALA - language_version = '0.1' + language_version = "0.1" language_info = { - 'name': 'scala', - 'mimetype': 'text/x-scala', - 'codemirror_mode': 'text/x-scala', - 'file_extension': '.sc', - 'pygments_lexer': 'scala' + "name": "scala", + "mimetype": "text/x-scala", + "codemirror_mode": "text/x-scala", + "file_extension": ".sc", + "pygments_lexer": "scala", } session_language = LANG_SCALA - super(SparkKernel, - self).__init__(implementation, implementation_version, language, - language_version, language_info, session_language, - **kwargs) + super(SparkKernel, self).__init__( + implementation, + implementation_version, + language, + language_version, + language_info, + session_language, + **kwargs + ) -if __name__ == '__main__': +if __name__ == "__main__": from ipykernel.kernelapp import IPKernelApp + IPKernelApp.launch_instance(kernel_class=SparkKernel) diff --git a/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py b/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py index 99c4176c3..f102a9cad 100644 --- a/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py +++ b/sparkmagic/sparkmagic/kernels/sparkrkernel/__init__.py @@ -1 +1 @@ -__version__ = '0.0.1' \ No newline at end of file +__version__ = "0.0.1" diff --git a/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py b/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py index c55567e65..85c95e790 100644 --- a/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py +++ b/sparkmagic/sparkmagic/kernels/sparkrkernel/sparkrkernel.py @@ -6,26 +6,32 @@ class SparkRKernel(SparkKernelBase): def __init__(self, **kwargs): - implementation = 'SparkR' - implementation_version = '1.0' + implementation = "SparkR" + implementation_version = "1.0" language = LANG_R - language_version = '0.1' + language_version = "0.1" language_info = { - 'name': 'sparkR', - 'mimetype': 'text/x-rsrc', - 'codemirror_mode': 'text/x-rsrc', - 'file_extension': '.r', - 'pygments_lexer': 'r' + "name": "sparkR", + "mimetype": "text/x-rsrc", + "codemirror_mode": "text/x-rsrc", + "file_extension": ".r", + "pygments_lexer": "r", } session_language = LANG_R - super(SparkRKernel, - self).__init__(implementation, implementation_version, language, - language_version, language_info, session_language, - **kwargs) + super(SparkRKernel, self).__init__( + implementation, + implementation_version, + language, + language_version, + language_info, + session_language, + **kwargs + ) -if __name__ == '__main__': +if __name__ == "__main__": from ipykernel.kernelapp import IPKernelApp + IPKernelApp.launch_instance(kernel_class=SparkRKernel) diff --git a/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py b/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py 
index 55c80e475..bd45b2efa 100644 --- a/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py +++ b/sparkmagic/sparkmagic/kernels/wrapperkernel/sparkkernelbase.py @@ -3,9 +3,11 @@ try: from asyncio import Future except ImportError: + class Future(object): """A class nothing will use.""" + import requests from ipykernel.ipkernel import IPythonKernel from hdijupyterutils.ipythondisplay import IpythonDisplay @@ -17,8 +19,17 @@ class Future(object): class SparkKernelBase(IPythonKernel): - def __init__(self, implementation, implementation_version, language, language_version, language_info, - session_language, user_code_parser=None, **kwargs): + def __init__( + self, + implementation, + implementation_version, + language, + language_version, + language_info, + session_language, + user_code_parser=None, + **kwargs + ): # Required by Jupyter - Override self.implementation = implementation self.implementation_version = implementation_version @@ -31,7 +42,7 @@ def __init__(self, implementation, implementation_version, language, language_ve super(SparkKernelBase, self).__init__(**kwargs) - self.logger = SparkLog(u"{}_jupyter_kernel".format(self.session_language)) + self.logger = SparkLog("{}_jupyter_kernel".format(self.session_language)) self._fatal_error = None self.ipython_display = IpythonDisplay() @@ -49,12 +60,17 @@ def __init__(self, implementation, implementation_version, language, language_ve if conf.use_auto_viz(): self._register_auto_viz() - def do_execute(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False): + def do_execute( + self, code, silent, store_history=True, user_expressions=None, allow_stdin=False + ): def f(self): if self._fatal_error is not None: return self._repeat_fatal_error() - return self._do_execute(code, silent, store_history, user_expressions, allow_stdin) + return self._do_execute( + code, silent, store_history, user_expressions, allow_stdin + ) + return wrap_unexpected_exceptions(f, self._complete_cell)(self) def do_shutdown(self, restart): @@ -66,54 +82,91 @@ def do_shutdown(self, restart): def _do_execute(self, code, silent, store_history, user_expressions, allow_stdin): code_to_run = self.user_code_parser.get_code_to_run(code) - res = self._execute_cell(code_to_run, silent, store_history, user_expressions, allow_stdin) + res = self._execute_cell( + code_to_run, silent, store_history, user_expressions, allow_stdin + ) return res def _load_magics_extension(self): register_magics_code = "%load_ext sparkmagic.kernels" - self._execute_cell(register_magics_code, True, False, shutdown_if_error=True, - log_if_error="Failed to load the Spark kernels magics library.") + self._execute_cell( + register_magics_code, + True, + False, + shutdown_if_error=True, + log_if_error="Failed to load the Spark kernels magics library.", + ) self.logger.debug("Loaded magics.") def _change_language(self): - register_magics_code = "%%_do_not_call_change_language -l {}\n ".format(self.session_language) - self._execute_cell(register_magics_code, True, False, shutdown_if_error=True, - log_if_error="Failed to change language to {}.".format(self.session_language)) + register_magics_code = "%%_do_not_call_change_language -l {}\n ".format( + self.session_language + ) + self._execute_cell( + register_magics_code, + True, + False, + shutdown_if_error=True, + log_if_error="Failed to change language to {}.".format( + self.session_language + ), + ) self.logger.debug("Changed language.") def _register_auto_viz(self): from sparkmagic.utils.sparkevents import 
get_spark_events_handler import autovizwidget.utils.configuration as c - + handler = get_spark_events_handler() c.override("events_handler", handler) - + register_auto_viz_code = """from autovizwidget.widget.utils import display_dataframe ip = get_ipython() ip.display_formatter.ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)""" - self._execute_cell(register_auto_viz_code, True, False, shutdown_if_error=True, - log_if_error="Failed to register auto viz for notebook.") + self._execute_cell( + register_auto_viz_code, + True, + False, + shutdown_if_error=True, + log_if_error="Failed to register auto viz for notebook.", + ) self.logger.debug("Registered auto viz.") def _delete_session(self): code = "%%_do_not_call_delete_session\n " self._execute_cell_for_user(code, True, False) - def _execute_cell(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False, - shutdown_if_error=False, log_if_error=None): - reply_content = self._execute_cell_for_user(code, silent, store_history, user_expressions, allow_stdin) - - if shutdown_if_error and reply_content[u"status"] == u"error": - error_from_reply = reply_content[u"evalue"] + def _execute_cell( + self, + code, + silent, + store_history=True, + user_expressions=None, + allow_stdin=False, + shutdown_if_error=False, + log_if_error=None, + ): + reply_content = self._execute_cell_for_user( + code, silent, store_history, user_expressions, allow_stdin + ) + + if shutdown_if_error and reply_content["status"] == "error": + error_from_reply = reply_content["evalue"] if log_if_error is not None: - message = "{}\nException details:\n\t\"{}\"".format(log_if_error, error_from_reply) + message = '{}\nException details:\n\t"{}"'.format( + log_if_error, error_from_reply + ) return self._abort_with_fatal_error(message) return reply_content - def _execute_cell_for_user(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False): - result = super(SparkKernelBase, self).do_execute(code, silent, store_history, user_expressions, allow_stdin) + def _execute_cell_for_user( + self, code, silent, store_history=True, user_expressions=None, allow_stdin=False + ): + result = super(SparkKernelBase, self).do_execute( + code, silent, store_history, user_expressions, allow_stdin + ) if isinstance(result, Future): result = result.result() return result diff --git a/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py b/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py index 1b3f1b4fb..47efbd994 100644 --- a/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py +++ b/sparkmagic/sparkmagic/kernels/wrapperkernel/usercodeparser.py @@ -9,9 +9,18 @@ class UserCodeParser(object): # For example, the %%info magic has no cell body input, i.e. 
it is incorrect to call # %%info # some_input - _magics_with_no_cell_body = [i.__name__ for i in [KernelMagics.info, KernelMagics.logs, KernelMagics.cleanup, - KernelMagics.delete, KernelMagics.help, KernelMagics.spark, - KernelMagics.send_to_spark]] + _magics_with_no_cell_body = [ + i.__name__ + for i in [ + KernelMagics.info, + KernelMagics.logs, + KernelMagics.cleanup, + KernelMagics.delete, + KernelMagics.help, + KernelMagics.spark, + KernelMagics.send_to_spark, + ] + ] def get_code_to_run(self, code): try: @@ -22,9 +31,9 @@ def get_code_to_run(self, code): if code.startswith("%%local") or code.startswith("%local"): return all_but_first_line elif any(code.startswith("%%" + s) for s in self._magics_with_no_cell_body): - return u"{}\n ".format(code) + return "{}\n ".format(code) elif any(code.startswith("%" + s) for s in self._magics_with_no_cell_body): - return u"%{}\n ".format(code) + return "%{}\n ".format(code) elif code.startswith("%%") or code.startswith("%"): # If they use other line magics: # %autosave @@ -34,4 +43,4 @@ def get_code_to_run(self, code): elif not code: return code else: - return u"%%spark\n{}".format(code) + return "%%spark\n{}".format(code) diff --git a/sparkmagic/sparkmagic/livyclientlib/command.py b/sparkmagic/sparkmagic/livyclientlib/command.py index 309a6158a..3a03266f3 100644 --- a/sparkmagic/sparkmagic/livyclientlib/command.py +++ b/sparkmagic/sparkmagic/livyclientlib/command.py @@ -11,18 +11,27 @@ import sparkmagic.utils.configuration as conf from sparkmagic.utils.sparklogger import SparkLog from sparkmagic.utils.sparkevents import SparkEvents -from sparkmagic.utils.constants import MAGICS_LOGGER_NAME, FINAL_STATEMENT_STATUS, \ - MIMETYPE_IMAGE_PNG, MIMETYPE_TEXT_HTML, MIMETYPE_TEXT_PLAIN, \ - COMMAND_INTERRUPTED_MSG, COMMAND_CANCELLATION_FAILED_MSG -from .exceptions import LivyUnexpectedStatusException, SparkStatementCancelledException, \ - SparkStatementCancellationFailedException +from sparkmagic.utils.constants import ( + MAGICS_LOGGER_NAME, + FINAL_STATEMENT_STATUS, + MIMETYPE_IMAGE_PNG, + MIMETYPE_TEXT_HTML, + MIMETYPE_TEXT_PLAIN, + COMMAND_INTERRUPTED_MSG, + COMMAND_CANCELLATION_FAILED_MSG, +) +from .exceptions import ( + LivyUnexpectedStatusException, + SparkStatementCancelledException, + SparkStatementCancellationFailedException, +) class Command(ObjectWithGuid): def __init__(self, code, spark_events=None): super(Command, self).__init__() self.code = textwrap.dedent(code) - self.logger = SparkLog(u"Command") + self.logger = SparkLog("Command") if spark_events is None: spark_events = SparkEvents() self._spark_events = spark_events @@ -37,68 +46,99 @@ def __ne__(self, other): return not self == other def execute(self, session): - self._spark_events.emit_statement_execution_start_event(session.guid, session.kind, session.id, self.guid) + self._spark_events.emit_statement_execution_start_event( + session.guid, session.kind, session.id, self.guid + ) statement_id = -1 try: session.wait_for_idle() - data = {u"code": self.code} + data = {"code": self.code} response = session.http_client.post_statement(session.id, data) - statement_id = response[u'id'] + statement_id = response["id"] output = self._get_statement_output(session, statement_id) except KeyboardInterrupt as e: - self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id, - self.guid, statement_id, False, e.__class__.__name__, - str(e)) + self._spark_events.emit_statement_execution_end_event( + session.guid, + session.kind, + session.id, + self.guid, + 
statement_id, + False, + e.__class__.__name__, + str(e), + ) try: if statement_id >= 0: - response = session.http_client.cancel_statement(session.id, statement_id) + response = session.http_client.cancel_statement( + session.id, statement_id + ) session.wait_for_idle() except: - raise SparkStatementCancellationFailedException(COMMAND_CANCELLATION_FAILED_MSG) + raise SparkStatementCancellationFailedException( + COMMAND_CANCELLATION_FAILED_MSG + ) else: raise SparkStatementCancelledException(COMMAND_INTERRUPTED_MSG) except Exception as e: - self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id, - self.guid, statement_id, False, e.__class__.__name__, - str(e)) + self._spark_events.emit_statement_execution_end_event( + session.guid, + session.kind, + session.id, + self.guid, + statement_id, + False, + e.__class__.__name__, + str(e), + ) raise else: - self._spark_events.emit_statement_execution_end_event(session.guid, session.kind, session.id, - self.guid, statement_id, True, "", "") + self._spark_events.emit_statement_execution_end_event( + session.guid, + session.kind, + session.id, + self.guid, + statement_id, + True, + "", + "", + ) return output def _get_statement_output(self, session, statement_id): retries = 1 - progress = FloatProgress(value=0.0, - min=0, - max=1.0, - step=0.01, - description='Progress:', - bar_style='info', - orientation='horizontal', - layout=Layout(width='50%', height='25px') - ) + progress = FloatProgress( + value=0.0, + min=0, + max=1.0, + step=0.01, + description="Progress:", + bar_style="info", + orientation="horizontal", + layout=Layout(width="50%", height="25px"), + ) session.ipython_display.display(progress) while True: statement = session.http_client.get_statement(session.id, statement_id) - status = statement[u"state"].lower() + status = statement["state"].lower() - self.logger.debug(u"Status of statement {} is {}.".format(statement_id, status)) + self.logger.debug( + "Status of statement {} is {}.".format(statement_id, status) + ) if status not in FINAL_STATEMENT_STATUS: - progress.value = statement.get('progress', 0.0) + progress.value = statement.get("progress", 0.0) session.sleep(retries) retries += 1 else: - statement_output = statement[u"output"] + statement_output = statement["output"] progress.close() if statement_output is None: - return (True, u"", MIMETYPE_TEXT_PLAIN) + return (True, "", MIMETYPE_TEXT_PLAIN) - if statement_output[u"status"] == u"ok": - data = statement_output[u"data"] + if statement_output["status"] == "ok": + data = statement_output["data"] if MIMETYPE_IMAGE_PNG in data: image = Image(base64.b64decode(data[MIMETYPE_IMAGE_PNG])) return (True, image, MIMETYPE_IMAGE_PNG) @@ -106,10 +146,17 @@ def _get_statement_output(self, session, statement_id): return (True, data[MIMETYPE_TEXT_HTML], MIMETYPE_TEXT_HTML) else: return (True, data[MIMETYPE_TEXT_PLAIN], MIMETYPE_TEXT_PLAIN) - elif statement_output[u"status"] == u"error": - return (False, - statement_output[u"evalue"] + u"\n" + u"".join(statement_output[u"traceback"]), - MIMETYPE_TEXT_PLAIN) + elif statement_output["status"] == "error": + return ( + False, + statement_output["evalue"] + + "\n" + + "".join(statement_output["traceback"]), + MIMETYPE_TEXT_PLAIN, + ) else: - raise LivyUnexpectedStatusException(u"Unknown output status from Livy: '{}'" - .format(statement_output[u"status"])) + raise LivyUnexpectedStatusException( + "Unknown output status from Livy: '{}'".format( + statement_output["status"] + ) + ) diff --git 
a/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py b/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py index 8ed854744..f4de99b81 100644 --- a/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py +++ b/sparkmagic/sparkmagic/livyclientlib/configurableretrypolicy.py @@ -18,7 +18,9 @@ def __init__(self, retry_seconds_to_sleep_list, max_retries): if len(retry_seconds_to_sleep_list) == 0: retry_seconds_to_sleep_list = [5] elif not all(n > 0 for n in retry_seconds_to_sleep_list): - raise BadUserConfigurationException(u"All items in the list in your config need to be positive for configurable retry policy") + raise BadUserConfigurationException( + "All items in the list in your config need to be positive for configurable retry policy" + ) self.retry_seconds_to_sleep_list = retry_seconds_to_sleep_list self._max_index = len(self.retry_seconds_to_sleep_list) - 1 diff --git a/sparkmagic/sparkmagic/livyclientlib/endpoint.py b/sparkmagic/sparkmagic/livyclientlib/endpoint.py index d49744382..59b32d36b 100644 --- a/sparkmagic/sparkmagic/livyclientlib/endpoint.py +++ b/sparkmagic/sparkmagic/livyclientlib/endpoint.py @@ -1,11 +1,12 @@ from .exceptions import BadUserDataException + class Endpoint(object): def __init__(self, url, auth, implicitly_added=False): if not url: - raise BadUserDataException(u"URL must not be empty") + raise BadUserDataException("URL must not be empty") - self.url = url.rstrip(u"/") + self.url = url.rstrip("/") self.auth = auth # implicitly_added is set to True only if the endpoint wasn't configured manually by the user through # a widget, but was instead implicitly defined as an endpoint to a wrapper kernel in the configuration @@ -24,4 +25,4 @@ def __ne__(self, other): return not self == other def __str__(self): - return u"Endpoint({})".format(self.url) + return "Endpoint({})".format(self.url) diff --git a/sparkmagic/sparkmagic/livyclientlib/exceptions.py b/sparkmagic/sparkmagic/livyclientlib/exceptions.py index 5c3bb0e5f..b048aa49f 100644 --- a/sparkmagic/sparkmagic/livyclientlib/exceptions.py +++ b/sparkmagic/sparkmagic/livyclientlib/exceptions.py @@ -24,6 +24,7 @@ class HttpClientException(LivyClientLibException): class LivyClientTimeoutException(LivyClientLibException): """An exception for timeouts while interacting with Livy.""" + class DataFrameParseException(LivyClientLibException): """An internal error which suggests a bad implementation of dataframe parsing from JSON -- if we get a JSON parsing error when parsing the results from the Livy server, this exception @@ -88,7 +89,17 @@ class SparkStatementCancellationFailedException(KeyboardInterrupt): # == DECORATORS FOR EXCEPTION HANDLING == -EXPECTED_EXCEPTIONS = [BadUserConfigurationException, BadUserDataException, LivyUnexpectedStatusException, SqlContextNotFoundException, HttpClientException, LivyClientTimeoutException, SessionManagementException, SparkStatementException] +EXPECTED_EXCEPTIONS = [ + BadUserConfigurationException, + BadUserDataException, + LivyUnexpectedStatusException, + SqlContextNotFoundException, + HttpClientException, + LivyClientTimeoutException, + SessionManagementException, + SparkStatementException, +] + def handle_expected_exceptions(f): """A decorator that handles expected exceptions. Self can be any object with @@ -98,6 +109,7 @@ def handle_expected_exceptions(f): def fn(self, ...): etc...""" from sparkmagic.utils import configuration as conf + exceptions_to_handle = tuple(EXPECTED_EXCEPTIONS) # Notice that we're NOT handling e.DataFrameParseException here. 
That's because DataFrameParseException @@ -114,6 +126,7 @@ def wrapped(self, *args, **kwargs): return None else: return out + wrapped.__name__ = f.__name__ wrapped.__doc__ = f.__doc__ return wrapped @@ -128,10 +141,15 @@ def wrap_unexpected_exceptions(f, execute_if_error=None): Usage: @wrap_unexpected_exceptions def fn(self, ...): - ..etc """ + ..etc""" from sparkmagic.utils import configuration as conf + def handle_exception(self, e): - self.logger.error(u"ENCOUNTERED AN INTERNAL ERROR: {}\n\tTraceback:\n{}".format(e, traceback.format_exc())) + self.logger.error( + "ENCOUNTERED AN INTERNAL ERROR: {}\n\tTraceback:\n{}".format( + e, traceback.format_exc() + ) + ) self.ipython_display.send_error(INTERNAL_ERROR_MSG.format(e)) return None if execute_if_error is None else execute_if_error() @@ -144,6 +162,7 @@ def wrapped(self, *args, **kwargs): return handle_exception(self, err) else: return out + wrapped.__name__ = f.__name__ wrapped.__doc__ = f.__doc__ return wrapped diff --git a/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py b/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py index 176ec7b79..fa54f154e 100644 --- a/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py +++ b/sparkmagic/sparkmagic/livyclientlib/livyreliablehttpclient.py @@ -12,22 +12,29 @@ class LivyReliableHttpClient(object): """A Livy-specific Http client which wraps the normal ReliableHttpClient. Propagates HttpClientExceptions up.""" + def __init__(self, http_client, endpoint): self.endpoint = endpoint self._http_client = http_client @staticmethod def from_endpoint(endpoint): - headers = {"Content-Type": "application/json" } + headers = {"Content-Type": "application/json"} headers.update(conf.custom_headers()) retry_policy = LivyReliableHttpClient._get_retry_policy() - return LivyReliableHttpClient(ReliableHttpClient(endpoint, headers, retry_policy), endpoint) + return LivyReliableHttpClient( + ReliableHttpClient(endpoint, headers, retry_policy), endpoint + ) def post_statement(self, session_id, data): - return self._http_client.post(self._statements_url(session_id), [201], data).json() + return self._http_client.post( + self._statements_url(session_id), [201], data + ).json() def get_statement(self, session_id, statement_id): - return self._http_client.get(self._statement_url(session_id, statement_id), [200]).json() + return self._http_client.get( + self._statement_url(session_id, statement_id), [200] + ).json() def get_sessions(self): return self._http_client.get("/sessions", [200]).json() @@ -42,13 +49,17 @@ def delete_session(self, session_id): self._http_client.delete(self._session_url(session_id), [200, 404]) def get_all_session_logs(self, session_id): - return self._http_client.get(self._session_url(session_id) + "/log?from=0", [200]).json() + return self._http_client.get( + self._session_url(session_id) + "/log?from=0", [200] + ).json() def get_headers(self): return self._http_client.get_headers() def cancel_statement(self, session_id, statement_id): - return self._http_client.post("{}/cancel".format(self._statement_url(session_id, statement_id)), [200], {}).json() + return self._http_client.post( + "{}/cancel".format(self._statement_url(session_id, statement_id)), [200], {} + ).json() @staticmethod def _session_url(session_id): @@ -69,6 +80,11 @@ def _get_retry_policy(): if policy == LINEAR_RETRY: return LinearRetryPolicy(seconds_to_sleep=5, max_retries=5) elif policy == CONFIGURABLE_RETRY: - return 
ConfigurableRetryPolicy(retry_seconds_to_sleep_list=conf.retry_seconds_to_sleep_list(), max_retries=conf.configurable_retry_policy_max_retries()) + return ConfigurableRetryPolicy( + retry_seconds_to_sleep_list=conf.retry_seconds_to_sleep_list(), + max_retries=conf.configurable_retry_policy_max_retries(), + ) else: - raise BadUserConfigurationException(u"Retry policy '{}' not supported".format(policy)) + raise BadUserConfigurationException( + "Retry policy '{}' not supported".format(policy) + ) diff --git a/sparkmagic/sparkmagic/livyclientlib/livysession.py b/sparkmagic/sparkmagic/livyclientlib/livysession.py index 81512d856..58643034c 100644 --- a/sparkmagic/sparkmagic/livyclientlib/livysession.py +++ b/sparkmagic/sparkmagic/livyclientlib/livysession.py @@ -11,8 +11,12 @@ from sparkmagic.utils.utils import get_sessions_info_html from .configurableretrypolicy import ConfigurableRetryPolicy from .command import Command -from .exceptions import LivyClientTimeoutException, \ - LivyUnexpectedStatusException, BadUserDataException, SqlContextNotFoundException +from .exceptions import ( + LivyClientTimeoutException, + LivyUnexpectedStatusException, + BadUserDataException, + SqlContextNotFoundException, +) class _HeartbeatThread(threading.Thread): @@ -41,10 +45,12 @@ def __init__(self, livy_session, refresh_seconds, retry_seconds, run_at_most=Non def run(self): loop_counter = 0 if self.livy_session is None: - print(u"Will not start heartbeat thread because self.livy_session is None") + print("Will not start heartbeat thread because self.livy_session is None") return - self.livy_session.logger.info(u'Starting heartbeat for session {}'.format(self.livy_session.id)) + self.livy_session.logger.info( + "Starting heartbeat for session {}".format(self.livy_session.id) + ) while self.livy_session is not None and loop_counter < self.run_at_most: loop_counter += 1 @@ -58,23 +64,31 @@ def run(self): # the "exception" function in the SparkLog class then you could just make this # self.livy_session.logger.exception("some useful message") and it'll print # out the stack trace too. 
- self.livy_session.logger.error(u'{}'.format(e)) + self.livy_session.logger.error("{}".format(e)) sleep(sleep_time) - def stop(self): if self.livy_session is not None: - self.livy_session.logger.info(u'Stopping heartbeat for session {}'.format(self.livy_session.id)) + self.livy_session.logger.info( + "Stopping heartbeat for session {}".format(self.livy_session.id) + ) self.livy_session = None self.join() class LivySession(ObjectWithGuid): - def __init__(self, http_client, properties, ipython_display, - session_id=-1, spark_events=None, - heartbeat_timeout=0, heartbeat_thread=None): + def __init__( + self, + http_client, + properties, + ipython_display, + session_id=-1, + spark_events=None, + heartbeat_timeout=0, + heartbeat_thread=None, + ): super(LivySession, self).__init__() assert constants.LIVY_KIND_PARAM in list(properties.keys()) kind = properties[constants.LIVY_KIND_PARAM] @@ -95,28 +109,33 @@ def __init__(self, http_client, properties, ipython_display, spark_events = SparkEvents() self._spark_events = spark_events - self._policy = ConfigurableRetryPolicy(retry_seconds_to_sleep_list=[0.2, 0.5, 0.5, 1, 1, 2], max_retries=5000) + self._policy = ConfigurableRetryPolicy( + retry_seconds_to_sleep_list=[0.2, 0.5, 0.5, 1, 1, 2], max_retries=5000 + ) wait_for_idle_timeout_seconds = conf.wait_for_idle_timeout_seconds() assert wait_for_idle_timeout_seconds > 0 - self.logger = SparkLog(u"LivySession") + self.logger = SparkLog("LivySession") kind = kind.lower() if kind not in constants.SESSION_KINDS_SUPPORTED: - raise BadUserDataException(u"Session of kind '{}' not supported. Session must be of kinds {}." - .format(kind, ", ".join(constants.SESSION_KINDS_SUPPORTED))) + raise BadUserDataException( + "Session of kind '{}' not supported. Session must be of kinds {}.".format( + kind, ", ".join(constants.SESSION_KINDS_SUPPORTED) + ) + ) self._app_id = None self._user = None - self._logs = u"" + self._logs = "" self._http_client = http_client self._wait_for_idle_timeout_seconds = wait_for_idle_timeout_seconds self._printed_resource_warning = False self.kind = kind self.id = session_id - self.session_info = u"" + self.session_info = "" self._heartbeat_thread = None if session_id == -1: @@ -126,8 +145,14 @@ def __init__(self, http_client, properties, ipython_display, self._start_heartbeat_thread() def __str__(self): - return u"Session id: {}\tYARN id: {}\tKind: {}\tState: {}\n\tSpark UI: {}\n\tDriver Log: {}"\ - .format(self.id, self.get_app_id(), self.kind, self.status, self.get_spark_ui_url(), self.get_driver_log_url()) + return "Session id: {}\tYARN id: {}\tKind: {}\tState: {}\n\tSpark UI: {}\n\tDriver Log: {}".format( + self.id, + self.get_app_id(), + self.kind, + self.status, + self.get_spark_ui_url(), + self.get_driver_log_url(), + ) def start(self): """Start the session against actual livy server.""" @@ -136,10 +161,10 @@ def start(self): try: r = self._http_client.post_session(self.properties) - self.id = r[u"id"] - self.status = str(r[u"state"]) + self.id = r["id"] + self.status = str(r["state"]) - self.ipython_display.writeln(u"Starting Spark application") + self.ipython_display.writeln("Starting Spark application") # Start heartbeat thread to keep Livy interactive session alive. self._start_heartbeat_thread() @@ -148,8 +173,11 @@ def start(self): try: self.wait_for_idle(conf.livy_session_startup_timeout_seconds()) except LivyClientTimeoutException: - raise LivyClientTimeoutException(u"Session {} did not start up in {} seconds." 
- .format(self.id, conf.livy_session_startup_timeout_seconds())) + raise LivyClientTimeoutException( + "Session {} did not start up in {} seconds.".format( + self.id, conf.livy_session_startup_timeout_seconds() + ) + ) html = get_sessions_info_html([self], self.id) self.ipython_display.html(html) @@ -158,26 +186,41 @@ def start(self): (success, out, mimetype) = command.execute(self) if success: - self.ipython_display.writeln(u"SparkSession available as 'spark'.") + self.ipython_display.writeln("SparkSession available as 'spark'.") self.sql_context_variable_name = "spark" else: command = Command("sqlContext") (success, out, mimetype) = command.execute(self) if success: - self.ipython_display.writeln(u"SparkContext available as 'sc'.") - if ("hive" in out.lower()): - self.ipython_display.writeln(u"HiveContext available as 'sqlContext'.") + self.ipython_display.writeln("SparkContext available as 'sc'.") + if "hive" in out.lower(): + self.ipython_display.writeln( + "HiveContext available as 'sqlContext'." + ) else: - self.ipython_display.writeln(u"SqlContext available as 'sqlContext'.") + self.ipython_display.writeln( + "SqlContext available as 'sqlContext'." + ) self.sql_context_variable_name = "sqlContext" else: - raise SqlContextNotFoundException(u"Neither SparkSession nor HiveContext/SqlContext is available.") + raise SqlContextNotFoundException( + "Neither SparkSession nor HiveContext/SqlContext is available." + ) except Exception as e: - self._spark_events.emit_session_creation_end_event(self.guid, self.kind, self.id, self.status, - False, e.__class__.__name__, str(e)) + self._spark_events.emit_session_creation_end_event( + self.guid, + self.kind, + self.id, + self.status, + False, + e.__class__.__name__, + str(e), + ) raise else: - self._spark_events.emit_session_creation_end_event(self.guid, self.kind, self.id, self.status, True, "", "") + self._spark_events.emit_session_creation_end_event( + self.guid, self.kind, self.id, self.status, True, "", "" + ) def get_app_id(self): if self._app_id is None: @@ -195,7 +238,7 @@ def get_driver_log_url(self): return self.get_app_info_member("driverLogUrl") def get_logs(self): - log_array = self._http_client.get_all_session_logs(self.id)[u'log'] + log_array = self._http_client.get_all_session_logs(self.id)["log"] self._logs = "\n".join(log_array) return self._logs @@ -225,10 +268,12 @@ def is_posted(self): def delete(self): session_id = self.id - self._spark_events.emit_session_deletion_start_event(self.guid, self.kind, session_id, self.status) + self._spark_events.emit_session_deletion_start_event( + self.guid, self.kind, session_id, self.status + ) try: - self.logger.debug(u"Deleting session '{}'".format(session_id)) + self.logger.debug("Deleting session '{}'".format(session_id)) if self.status != constants.NOT_STARTED_SESSION_STATUS: self._http_client.delete_session(session_id) @@ -236,15 +281,27 @@ def delete(self): self.status = constants.DEAD_SESSION_STATUS self.id = -1 else: - self.ipython_display.send_error(u"Cannot delete session {} that is in state '{}'." 
- .format(session_id, self.status)) + self.ipython_display.send_error( + "Cannot delete session {} that is in state '{}'.".format( + session_id, self.status + ) + ) except Exception as e: - self._spark_events.emit_session_deletion_end_event(self.guid, self.kind, session_id, self.status, False, - e.__class__.__name__, str(e)) + self._spark_events.emit_session_deletion_end_event( + self.guid, + self.kind, + session_id, + self.status, + False, + e.__class__.__name__, + str(e), + ) raise else: - self._spark_events.emit_session_deletion_end_event(self.guid, self.kind, session_id, self.status, True, "", "") + self._spark_events.emit_session_deletion_end_event( + self.guid, self.kind, session_id, self.status, True, "", "" + ) def wait_for_idle(self, seconds_to_wait=None): """Wait for session to go to idle status. Sleep meanwhile. @@ -262,29 +319,41 @@ def wait_for_idle(self, seconds_to_wait=None): return if self.status in constants.FINAL_STATUS: - error = u"Session {} unexpectedly reached final status '{}'."\ - .format(self.id, self.status) + error = "Session {} unexpectedly reached final status '{}'.".format( + self.id, self.status + ) self.logger.error(error) - raise LivyUnexpectedStatusException(u'{} See logs:\n{}'.format(error, self.get_logs())) + raise LivyUnexpectedStatusException( + "{} See logs:\n{}".format(error, self.get_logs()) + ) if seconds_to_wait <= 0.0: - error = u"Session {} did not reach idle status in time. Current status is {}."\ - .format(self.id, self.status) + error = "Session {} did not reach idle status in time. Current status is {}.".format( + self.id, self.status + ) self.logger.error(error) raise LivyClientTimeoutException(error) - if constants.YARN_RESOURCE_LIMIT_MSG in self.session_info and \ - not self._printed_resource_warning: - self.ipython_display.send_error(constants.RESOURCE_LIMIT_WARNING\ - .format(conf.resource_limit_mitigation_suggestion())) + if ( + constants.YARN_RESOURCE_LIMIT_MSG in self.session_info + and not self._printed_resource_warning + ): + self.ipython_display.send_error( + constants.RESOURCE_LIMIT_WARNING.format( + conf.resource_limit_mitigation_suggestion() + ) + ) self._printed_resource_warning = True start_time = time() sleep_time = self._policy.seconds_to_sleep(retries) retries += 1 - self.logger.debug(u"Session {} in state {}. Sleeping {} seconds." - .format(self.id, self.status, sleep_time)) + self.logger.debug( + "Session {} in state {}. Sleeping {} seconds.".format( + self.id, self.status, sleep_time + ) + ) sleep(sleep_time) seconds_to_wait -= time() - start_time @@ -295,14 +364,16 @@ def sleep(self, retries): # Only the status will be returned as the return value. 
def refresh_status_and_info(self): response = self._http_client.get_session(self.id) - status = response[u'state'] - log_array = response[u'log'] + status = response["state"] + log_array = response["log"] if status in constants.POSSIBLE_SESSION_STATUS: self.status = status - self.session_info = u"\n".join(log_array) + self.session_info = "\n".join(log_array) else: - raise LivyUnexpectedStatusException(u"Status '{}' not supported by session.".format(status)) + raise LivyUnexpectedStatusException( + "Status '{}' not supported by session.".format(status) + ) def _start_heartbeat_thread(self): if self._should_heartbeat and self._heartbeat_thread is None: @@ -310,7 +381,9 @@ def _start_heartbeat_thread(self): retry_seconds = conf.heartbeat_retry_seconds() if self._user_passed_heartbeat_thread is None: - self._heartbeat_thread = _HeartbeatThread(self, refresh_seconds, retry_seconds) + self._heartbeat_thread = _HeartbeatThread( + self, refresh_seconds, retry_seconds + ) else: self._heartbeat_thread = self._user_passed_heartbeat_thread @@ -323,15 +396,22 @@ def _stop_heartbeat_thread(self): self._heartbeat_thread = None def get_row_html(self, current_session_id): - return u"""""".format( - self.id, self.get_app_id(), self.kind, self.status, - self.get_html_link(u'Link', self.get_spark_ui_url()), self.get_html_link(u'Link', self.get_driver_log_url()), - self.get_user(), u"" if current_session_id is None or current_session_id != self.id else u"\u2714" + return """""".format( + self.id, + self.get_app_id(), + self.kind, + self.status, + self.get_html_link("Link", self.get_spark_ui_url()), + self.get_html_link("Link", self.get_driver_log_url()), + self.get_user(), + "" + if current_session_id is None or current_session_id != self.id + else "\u2714", ) @staticmethod def get_html_link(text, url): if url is not None: - return u"""{0}""".format(text, url) + return """{0}""".format(text, url) else: - return u"" + return "" diff --git a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py index 75786347b..29f3906da 100644 --- a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py +++ b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py @@ -7,6 +7,7 @@ from sparkmagic.utils.sparklogger import SparkLog from .exceptions import HttpClientException + class ReliableHttpClient(object): """Http client that is reliable in its requests. Uses requests library.""" @@ -16,49 +17,71 @@ def __init__(self, endpoint, headers, retry_policy): self._retry_policy = retry_policy self._auth = self._endpoint.auth self._session = requests.Session() - self.logger = SparkLog(u"ReliableHttpClient") + self.logger = SparkLog("ReliableHttpClient") self.verify_ssl = not conf.ignore_ssl_errors() if not self.verify_ssl: - self.logger.debug(u"ATTENTION: Will ignore SSL errors. This might render you vulnerable to attacks.") + self.logger.debug( + "ATTENTION: Will ignore SSL errors. This might render you vulnerable to attacks." + ) requests.packages.urllib3.disable_warnings() def get_headers(self): return self._headers def compose_url(self, relative_url): - r_u = "/{}".format(relative_url.rstrip(u"/").lstrip(u"/")) + r_u = "/{}".format(relative_url.rstrip("/").lstrip("/")) return self._endpoint.url + r_u def get(self, relative_url, accepted_status_codes): """Sends a get request. 
Returns a response.""" - return self._send_request(relative_url, accepted_status_codes, self._session.get) + return self._send_request( + relative_url, accepted_status_codes, self._session.get + ) def post(self, relative_url, accepted_status_codes, data): """Sends a post request. Returns a response.""" - return self._send_request(relative_url, accepted_status_codes, self._session.post, data) + return self._send_request( + relative_url, accepted_status_codes, self._session.post, data + ) def delete(self, relative_url, accepted_status_codes): """Sends a delete request. Returns a response.""" - return self._send_request(relative_url, accepted_status_codes, self._session.delete) + return self._send_request( + relative_url, accepted_status_codes, self._session.delete + ) def _send_request(self, relative_url, accepted_status_codes, function, data=None): - return self._send_request_helper(self.compose_url(relative_url), accepted_status_codes, function, data, 0) + return self._send_request_helper( + self.compose_url(relative_url), accepted_status_codes, function, data, 0 + ) - def _send_request_helper(self, url, accepted_status_codes, function, data, retry_count): + def _send_request_helper( + self, url, accepted_status_codes, function, data, retry_count + ): while True: try: if data is None: - r = function(url, headers=self._headers, auth=self._auth, verify=self.verify_ssl) + r = function( + url, + headers=self._headers, + auth=self._auth, + verify=self.verify_ssl, + ) else: - r = function(url, headers=self._headers, auth=self._auth, - data=json.dumps(data), verify=self.verify_ssl) + r = function( + url, + headers=self._headers, + auth=self._auth, + data=json.dumps(data), + verify=self.verify_ssl, + ) except requests.exceptions.RequestException as e: error = True r = None status = None text = None - self.logger.error(u"Request to '{}' failed with '{}'".format(url, e)) + self.logger.error("Request to '{}' failed with '{}'".format(url, e)) else: error = False status = r.status_code @@ -71,8 +94,13 @@ def _send_request_helper(self, url, accepted_status_codes, function, data, retry continue if error: - raise HttpClientException(u"Error sending http request and maximum retry encountered.") + raise HttpClientException( + "Error sending http request and maximum retry encountered." 
+ ) else: - raise HttpClientException(u"Invalid status code '{}' from {} with error payload: {}" - .format(status, url, text)) + raise HttpClientException( + "Invalid status code '{}' from {} with error payload: {}".format( + status, url, text + ) + ) return r diff --git a/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py index b145780bf..3e7f26c04 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py +++ b/sparkmagic/sparkmagic/livyclientlib/sendpandasdftosparkcommand.py @@ -9,10 +9,11 @@ import pandas as pd + class SendPandasDfToSparkCommand(SendToSparkCommand): # convert unicode to utf8 or pyspark will mark data as corrupted(and deserialize incorrectly) - _python_decode = u''' + _python_decode = """ import sys import json @@ -37,19 +38,25 @@ def _byteify(data, ignore_dicts = False): for key, value in data.iteritems() } return data - ''' - - def __init__(self, input_variable_name, input_variable_value, output_variable_name, max_rows): - super(SendPandasDfToSparkCommand, self).__init__(input_variable_name, input_variable_value, output_variable_name) + """ + + def __init__( + self, input_variable_name, input_variable_value, output_variable_name, max_rows + ): + super(SendPandasDfToSparkCommand, self).__init__( + input_variable_name, input_variable_value, output_variable_name + ) self.max_rows = max_rows def _scala_command(self, input_variable_name, pandas_df, output_variable_name): self._assert_input_is_pandas_dataframe(input_variable_name, pandas_df) pandas_json = self._get_dataframe_as_json(pandas_df) - scala_code = u''' + scala_code = ''' val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil) - val {} = spark.read.json(rdd_json_array)'''.format(pandas_json, output_variable_name) + val {} = spark.read.json(rdd_json_array)'''.format( + pandas_json, output_variable_name + ) return Command(scala_code) @@ -60,10 +67,12 @@ def _pyspark_command(self, input_variable_name, pandas_df, output_variable_name) pandas_json = self._get_dataframe_as_json(pandas_df) - pyspark_code += u''' + pyspark_code += """ json_array = json_loads_byteified('{}') rdd_json_array = spark.sparkContext.parallelize(json_array) - {} = spark.read.json(rdd_json_array)'''.format(pandas_json, output_variable_name) + {} = spark.read.json(rdd_json_array)""".format( + pandas_json, output_variable_name + ) return Command(pyspark_code) @@ -71,20 +80,28 @@ def _r_command(self, input_variable_name, pandas_df, output_variable_name): self._assert_input_is_pandas_dataframe(input_variable_name, pandas_df) pandas_json = self._get_dataframe_as_json(pandas_df) - r_code = u''' + r_code = """ fileConn<-file("temporary_pandas_df_sparkmagics.txt") writeLines('{}', fileConn) close(fileConn) {} <- read.json("temporary_pandas_df_sparkmagics.txt") {}.persist() - file.remove("temporary_pandas_df_sparkmagics.txt")'''.format(pandas_json, output_variable_name, output_variable_name) + file.remove("temporary_pandas_df_sparkmagics.txt")""".format( + pandas_json, output_variable_name, output_variable_name + ) return Command(r_code) def _get_dataframe_as_json(self, pandas_df): - return pandas_df.head(self.max_rows).to_json(orient=u'records') + return pandas_df.head(self.max_rows).to_json(orient="records") - def _assert_input_is_pandas_dataframe(self, input_variable_name, input_variable_value): + def _assert_input_is_pandas_dataframe( + self, input_variable_name, input_variable_value + ): if not isinstance(input_variable_value, pd.DataFrame): 
wrong_type = input_variable_value.__class__.__name__ - raise BadUserDataException(u'{} is not a Pandas DataFrame! Got {} instead.'.format(input_variable_name, wrong_type)) + raise BadUserDataException( + "{} is not a Pandas DataFrame! Got {} instead.".format( + input_variable_name, wrong_type + ) + ) diff --git a/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py index b2b861a83..29ea3dee8 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py +++ b/sparkmagic/sparkmagic/livyclientlib/sendstringtosparkcommand.py @@ -5,25 +5,43 @@ from sparkmagic.livyclientlib.command import Command from sparkmagic.livyclientlib.exceptions import BadUserDataException -class SendStringToSparkCommand(SendToSparkCommand): - def _scala_command(self, input_variable_name, input_variable_value, output_variable_name): +class SendStringToSparkCommand(SendToSparkCommand): + def _scala_command( + self, input_variable_name, input_variable_value, output_variable_name + ): self._assert_input_is_string_type(input_variable_name, input_variable_value) - scala_code = u'var {} = """{}"""'.format(output_variable_name, input_variable_value) + scala_code = 'var {} = """{}"""'.format( + output_variable_name, input_variable_value + ) return Command(scala_code) - def _pyspark_command(self, input_variable_name, input_variable_value, output_variable_name): + def _pyspark_command( + self, input_variable_name, input_variable_value, output_variable_name + ): self._assert_input_is_string_type(input_variable_name, input_variable_value) - pyspark_code = u'{} = {}'.format(output_variable_name, repr(input_variable_value)) + pyspark_code = "{} = {}".format( + output_variable_name, repr(input_variable_value) + ) return Command(pyspark_code) - def _r_command(self, input_variable_name, input_variable_value, output_variable_name): + def _r_command( + self, input_variable_name, input_variable_value, output_variable_name + ): self._assert_input_is_string_type(input_variable_name, input_variable_value) - escaped_input_variable_value = input_variable_value.replace(u'\\', u'\\\\').replace(u'"',u'\\"') - r_code = u'''assign("{}","{}")'''.format(output_variable_name, escaped_input_variable_value) + escaped_input_variable_value = input_variable_value.replace( + "\\", "\\\\" + ).replace('"', '\\"') + r_code = """assign("{}","{}")""".format( + output_variable_name, escaped_input_variable_value + ) return Command(r_code) def _assert_input_is_string_type(self, input_variable_name, input_variable_value): if not isinstance(input_variable_value, str): wrong_type = input_variable_value.__class__.__name__ - raise BadUserDataException(u'{} is not a str or bytes! Got {} instead'.format(input_variable_name, wrong_type)) + raise BadUserDataException( + "{} is not a str or bytes! 
Got {} instead".format( + input_variable_name, wrong_type + ) + ) diff --git a/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py b/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py index 5aa45f5a5..cc2cb729c 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py +++ b/sparkmagic/sparkmagic/livyclientlib/sendtosparkcommand.py @@ -7,8 +7,15 @@ from abc import abstractmethod + class SendToSparkCommand(Command): - def __init__(self, input_variable_name, input_variable_value, output_variable_name, spark_events=None): + def __init__( + self, + input_variable_name, + input_variable_value, + output_variable_name, + spark_events=None, + ): super(SendToSparkCommand, self).__init__("", spark_events) self.input_variable_name = input_variable_name self.input_variable_value = input_variable_value @@ -16,29 +23,48 @@ def __init__(self, input_variable_name, input_variable_value, output_variable_na def execute(self, session): try: - command = self.to_command(session.kind, self.input_variable_name, self.input_variable_value, self.output_variable_name) + command = self.to_command( + session.kind, + self.input_variable_name, + self.input_variable_value, + self.output_variable_name, + ) return command.execute(session) except Exception as e: raise e - def to_command(self, kind, input_variable_name, input_variable_value, output_variable_name): + def to_command( + self, kind, input_variable_name, input_variable_value, output_variable_name + ): if kind == constants.SESSION_KIND_PYSPARK: - return self._pyspark_command(input_variable_name, input_variable_value, output_variable_name) + return self._pyspark_command( + input_variable_name, input_variable_value, output_variable_name + ) elif kind == constants.SESSION_KIND_SPARK: - return self._scala_command(input_variable_name, input_variable_value, output_variable_name) + return self._scala_command( + input_variable_name, input_variable_value, output_variable_name + ) elif kind == constants.SESSION_KIND_SPARKR: - return self._r_command(input_variable_name, input_variable_value, output_variable_name) + return self._r_command( + input_variable_name, input_variable_value, output_variable_name + ) else: - raise BadUserDataException(u"Kind '{}' is not supported.".format(kind)) + raise BadUserDataException("Kind '{}' is not supported.".format(kind)) @abstractmethod - def _scala_command(self, input_variable_name, input_variable_value, output_variable_name): - raise NotImplementedError #override and provide proper implementation in supertype! + def _scala_command( + self, input_variable_name, input_variable_value, output_variable_name + ): + raise NotImplementedError # override and provide proper implementation in supertype! @abstractmethod - def _pyspark_command(self, input_variable_name, input_variable_value, output_variable_name): - raise NotImplementedError #override and provide proper implementation in supertype! + def _pyspark_command( + self, input_variable_name, input_variable_value, output_variable_name + ): + raise NotImplementedError # override and provide proper implementation in supertype! @abstractmethod - def _r_command(self, input_variable_name, input_variable_value, output_variable_name): - raise NotImplementedError #override and provide proper implementation in supertype! + def _r_command( + self, input_variable_name, input_variable_value, output_variable_name + ): + raise NotImplementedError # override and provide proper implementation in supertype! 
diff --git a/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py b/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py index 9ba850dd0..2004587c4 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py +++ b/sparkmagic/sparkmagic/livyclientlib/sessionmanager.py @@ -9,7 +9,7 @@ class SessionManager(object): def __init__(self, ipython_display): - self.logger = SparkLog(u"SessionManager") + self.logger = SparkLog("SessionManager") self.ipython_display = ipython_display self._sessions = dict() @@ -24,12 +24,17 @@ def get_sessions_list(self): return list(self._sessions.keys()) def get_sessions_info(self): - return [u"Name: {}\t{}".format(k, str(self._sessions[k])) for k in list(self._sessions.keys())] + return [ + "Name: {}\t{}".format(k, str(self._sessions[k])) + for k in list(self._sessions.keys()) + ] def add_session(self, name, session): if name in self._sessions: - raise SessionManagementException(u"Session with name '{}' already exists. Please delete the session" - u" first if you intend to replace it.".format(name)) + raise SessionManagementException( + "Session with name '{}' already exists. Please delete the session" + " first if you intend to replace it.".format(name) + ) self._sessions[name] = session @@ -39,16 +44,24 @@ def get_any_session(self): key = self.get_sessions_list()[0] return self._sessions[key] elif number_of_sessions == 0: - raise SessionManagementException(u"You need to have at least 1 client created to execute commands.") + raise SessionManagementException( + "You need to have at least 1 client created to execute commands." + ) else: - raise SessionManagementException(u"Please specify the client to use. Possible sessions are {}".format( - self.get_sessions_list())) - + raise SessionManagementException( + "Please specify the client to use. Possible sessions are {}".format( + self.get_sessions_list() + ) + ) + def get_session(self, name): if name in self._sessions: return self._sessions[name] - raise SessionManagementException(u"Could not find '{}' session in list of saved sessions. Possible sessions are {}".format( - name, self.get_sessions_list())) + raise SessionManagementException( + "Could not find '{}' session in list of saved sessions. Possible sessions are {}".format( + name, self.get_sessions_list() + ) + ) def get_session_id_for_client(self, name): if name in self.get_sessions_list(): @@ -63,7 +76,7 @@ def get_session_name_by_id_endpoint(self, id, endpoint): def delete_client(self, name): self._remove_session(name) - + def clean_up_all(self): for name in self.get_sessions_list(): self._remove_session(name) @@ -73,19 +86,26 @@ def _remove_session(self, name): self._sessions[name].delete() del self._sessions[name] else: - raise SessionManagementException(u"Could not find '{}' session in list of saved sessions. Possible sessions are {}" - .format(name, self.get_sessions_list())) + raise SessionManagementException( + "Could not find '{}' session in list of saved sessions. 
Possible sessions are {}".format( + name, self.get_sessions_list() + ) + ) def _register_cleanup_on_exit(self): """ Stop the livy sessions before python process exits for any reason (if enabled in conf) """ if conf.cleanup_all_sessions_on_exit(): + def cleanup_spark_sessions(): try: self.clean_up_all() except Exception as e: - self.logger.error(u"Error cleaning up sessions on exit: {}".format(e)) + self.logger.error( + "Error cleaning up sessions on exit: {}".format(e) + ) pass + atexit.register(cleanup_spark_sessions) - self.ipython_display.writeln(u"Cleaning up livy sessions on exit is enabled") + self.ipython_display.writeln("Cleaning up livy sessions on exit is enabled") diff --git a/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py b/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py index da8ec6298..419ee08f0 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py +++ b/sparkmagic/sparkmagic/livyclientlib/sparkcontroller.py @@ -10,9 +10,8 @@ class SparkController(object): - def __init__(self, ipython_display): - self.logger = SparkLog(u"SparkController") + self.logger = SparkLog("SparkController") self.ipython_display = ipython_display self.session_manager = SessionManager(ipython_display) # this is to reuse the already created http clients @@ -45,11 +44,22 @@ def run_sqlquery(self, sqlquery, client_name=None): def get_all_sessions_endpoint(self, endpoint): http_client = self._http_client(endpoint) - sessions = http_client.get_sessions()[u"sessions"] - supported_sessions = filter(lambda s: (s[constants.LIVY_KIND_PARAM] in constants.SESSION_KINDS_SUPPORTED), sessions) - session_list = [self._livy_session(http_client, {constants.LIVY_KIND_PARAM: s[constants.LIVY_KIND_PARAM]}, - self.ipython_display, s[u"id"]) - for s in supported_sessions] + sessions = http_client.get_sessions()["sessions"] + supported_sessions = filter( + lambda s: ( + s[constants.LIVY_KIND_PARAM] in constants.SESSION_KINDS_SUPPORTED + ), + sessions, + ) + session_list = [ + self._livy_session( + http_client, + {constants.LIVY_KIND_PARAM: s[constants.LIVY_KIND_PARAM]}, + self.ipython_display, + s["id"], + ) + for s in supported_sessions + ] for s in session_list: s.refresh_status_and_info() return session_list @@ -69,7 +79,9 @@ def delete_session_by_name(self, name): self.session_manager.delete_client(name) def delete_session_by_id(self, endpoint, session_id): - name = self.session_manager.get_session_name_by_id_endpoint(session_id, endpoint) + name = self.session_manager.get_session_name_by_id_endpoint( + session_id, endpoint + ) if name in self.session_manager.get_sessions_list(): self.delete_session_by_name(name) @@ -77,13 +89,21 @@ def delete_session_by_id(self, endpoint, session_id): http_client = self._http_client(endpoint) response = http_client.get_session(session_id) http_client = self._http_client(endpoint) - session = self._livy_session(http_client, {constants.LIVY_KIND_PARAM: response[constants.LIVY_KIND_PARAM]}, - self.ipython_display, session_id) + session = self._livy_session( + http_client, + {constants.LIVY_KIND_PARAM: response[constants.LIVY_KIND_PARAM]}, + self.ipython_display, + session_id, + ) session.delete() def add_session(self, name, endpoint, skip_if_exists, properties): if skip_if_exists and (name in self.session_manager.get_sessions_list()): - self.logger.debug(u"Skipping {} because it already exists in list of sessions.".format(name)) + self.logger.debug( + "Skipping {} because it already exists in list of sessions.".format( + name + ) + ) return http_client = 
self._http_client(endpoint) session = self._livy_session(http_client, properties, self.ipython_display) @@ -97,7 +117,6 @@ def add_session(self, name, endpoint, skip_if_exists, properties): else: self.session_manager.add_session(name, session) - def get_session_id_for_client(self, name): return self.session_manager.get_session_id_for_client(name) @@ -117,12 +136,18 @@ def get_managed_clients(self): return self.session_manager.sessions @staticmethod - def _livy_session(http_client, properties, ipython_display, - session_id=-1): - return LivySession(http_client, properties, ipython_display, - session_id, heartbeat_timeout=conf.livy_server_heartbeat_timeout_seconds()) + def _livy_session(http_client, properties, ipython_display, session_id=-1): + return LivySession( + http_client, + properties, + ipython_display, + session_id, + heartbeat_timeout=conf.livy_server_heartbeat_timeout_seconds(), + ) def _http_client(self, endpoint): if endpoint not in self._http_clients: - self._http_clients[endpoint] = LivyReliableHttpClient.from_endpoint(endpoint) + self._http_clients[endpoint] = LivyReliableHttpClient.from_endpoint( + endpoint + ) return self._http_clients[endpoint] diff --git a/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py b/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py index 2b5919fc3..6bcc836d0 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py +++ b/sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py @@ -5,13 +5,25 @@ import sparkmagic.utils.configuration as conf from sparkmagic.utils.sparkevents import SparkEvents from sparkmagic.livyclientlib.command import Command -from sparkmagic.livyclientlib.exceptions import DataFrameParseException, BadUserDataException +from sparkmagic.livyclientlib.exceptions import ( + DataFrameParseException, + BadUserDataException, +) import sparkmagic.utils.constants as constants import ast + class SparkStoreCommand(Command): - def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None): + def __init__( + self, + output_var, + samplemethod=None, + maxrows=None, + samplefraction=None, + spark_events=None, + coerce=None, + ): super(SparkStoreCommand, self).__init__("", spark_events) if samplemethod is None: @@ -21,12 +33,16 @@ def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=N if samplefraction is None: samplefraction = conf.default_samplefraction() - if samplemethod not in {u'take', u'sample'}: - raise BadUserDataException(u'samplemethod (-m) must be one of (take, sample)') + if samplemethod not in {"take", "sample"}: + raise BadUserDataException( + "samplemethod (-m) must be one of (take, sample)" + ) if not isinstance(maxrows, int): - raise BadUserDataException(u'maxrows (-n) must be an integer') + raise BadUserDataException("maxrows (-n) must be an integer") if not 0.0 <= samplefraction <= 1.0: - raise BadUserDataException(u'samplefraction (-r) must be a float between 0.0 and 1.0') + raise BadUserDataException( + "samplefraction (-r) must be a float between 0.0 and 1.0" + ) self.samplemethod = samplemethod self.maxrows = maxrows @@ -37,7 +53,6 @@ def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=N self._spark_events = spark_events self._coerce = coerce - def execute(self, session): try: command = self.to_command(session.kind, self.output_var) @@ -50,7 +65,6 @@ def execute(self, session): else: return result - def to_command(self, kind, spark_context_variable_name): if kind == 
constants.SESSION_KIND_PYSPARK: return self._pyspark_command(spark_context_variable_name) @@ -59,64 +73,63 @@ def to_command(self, kind, spark_context_variable_name): elif kind == constants.SESSION_KIND_SPARKR: return self._r_command(spark_context_variable_name) else: - raise BadUserDataException(u"Kind '{}' is not supported.".format(kind)) - + raise BadUserDataException("Kind '{}' is not supported.".format(kind)) def _pyspark_command(self, spark_context_variable_name): # use_unicode=False means the result will be UTF-8-encoded bytes, so we # set it to False for Python 2. - command = u'{}.toJSON(use_unicode=(sys.version_info.major > 2))'.format( - spark_context_variable_name) - if self.samplemethod == u'sample': - command = u'{}.sample(False, {})'.format(command, self.samplefraction) + command = "{}.toJSON(use_unicode=(sys.version_info.major > 2))".format( + spark_context_variable_name + ) + if self.samplemethod == "sample": + command = "{}.sample(False, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'{}.take({})'.format(command, self.maxrows) + command = "{}.take({})".format(command, self.maxrows) else: - command = u'{}.collect()'.format(command) + command = "{}.collect()".format(command) # Unicode support has improved in Python 3 so we don't need to encode. print_command = constants.LONG_RANDOM_VARIABLE_NAME - command = u'import sys\nfor {} in {}: print({})'.format( - constants.LONG_RANDOM_VARIABLE_NAME, - command, - print_command) + command = "import sys\nfor {} in {}: print({})".format( + constants.LONG_RANDOM_VARIABLE_NAME, command, print_command + ) return Command(command) - def _scala_command(self, spark_context_variable_name): - command = u'{}.toJSON'.format(spark_context_variable_name) - if self.samplemethod == u'sample': - command = u'{}.sample(false, {})'.format(command, self.samplefraction) + command = "{}.toJSON".format(spark_context_variable_name) + if self.samplemethod == "sample": + command = "{}.sample(false, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'{}.take({})'.format(command, self.maxrows) + command = "{}.take({})".format(command, self.maxrows) else: - command = u'{}.collect'.format(command) - return Command(u'{}.foreach(println)'.format(command)) - + command = "{}.collect".format(command) + return Command("{}.foreach(println)".format(command)) def _r_command(self, spark_context_variable_name): command = spark_context_variable_name - if self.samplemethod == u'sample': - command = u'sample({}, FALSE, {})'.format(command, - self.samplefraction) + if self.samplemethod == "sample": + command = "sample({}, FALSE, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'take({},{})'.format(command, self.maxrows) + command = "take({},{})".format(command, self.maxrows) else: - command = u'collect({})'.format(command) - command = u'jsonlite::toJSON({})'.format(command) - command = u'for ({} in ({})) {{cat({})}}'.format(constants.LONG_RANDOM_VARIABLE_NAME, - command, - constants.LONG_RANDOM_VARIABLE_NAME) + command = "collect({})".format(command) + command = "jsonlite::toJSON({})".format(command) + command = "for ({} in ({})) {{cat({})}}".format( + constants.LONG_RANDOM_VARIABLE_NAME, + command, + constants.LONG_RANDOM_VARIABLE_NAME, + ) return Command(command) - # Used only for unit testing def __eq__(self, other): - return self.code == other.code and \ - self.samplemethod == other.samplemethod and \ - self.maxrows == other.maxrows and \ - self.samplefraction == other.samplefraction and \ - 
self.output_var == other.output_var and \ - self._coerce == other._coerce + return ( + self.code == other.code + and self.samplemethod == other.samplemethod + and self.maxrows == other.maxrows + and self.samplefraction == other.samplefraction + and self.output_var == other.output_var + and self._coerce == other._coerce + ) def __ne__(self, other): return not (self == other) diff --git a/sparkmagic/sparkmagic/livyclientlib/sqlquery.py b/sparkmagic/sparkmagic/livyclientlib/sqlquery.py index 6f460810b..3e821c85c 100644 --- a/sparkmagic/sparkmagic/livyclientlib/sqlquery.py +++ b/sparkmagic/sparkmagic/livyclientlib/sqlquery.py @@ -1,6 +1,9 @@ from hdijupyterutils.guid import ObjectWithGuid -from sparkmagic.utils.utils import coerce_pandas_df_to_numeric_datetime, records_to_dataframe +from sparkmagic.utils.utils import ( + coerce_pandas_df_to_numeric_datetime, + records_to_dataframe, +) import sparkmagic.utils.configuration as conf import sparkmagic.utils.constants as constants from sparkmagic.utils.sparkevents import SparkEvents @@ -9,9 +12,17 @@ class SQLQuery(ObjectWithGuid): - def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None): + def __init__( + self, + query, + samplemethod=None, + maxrows=None, + samplefraction=None, + spark_events=None, + coerce=None, + ): super(SQLQuery, self).__init__() - + if samplemethod is None: samplemethod = conf.default_samplemethod() if maxrows is None: @@ -19,12 +30,16 @@ def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None, if samplefraction is None: samplefraction = conf.default_samplefraction() - if samplemethod not in {u'take', u'sample'}: - raise BadUserDataException(u'samplemethod (-m) must be one of (take, sample)') + if samplemethod not in {"take", "sample"}: + raise BadUserDataException( + "samplemethod (-m) must be one of (take, sample)" + ) if not isinstance(maxrows, int): - raise BadUserDataException(u'maxrows (-n) must be an integer') + raise BadUserDataException("maxrows (-n) must be an integer") if not 0.0 <= samplefraction <= 1.0: - raise BadUserDataException(u'samplefraction (-r) must be a float between 0.0 and 1.0') + raise BadUserDataException( + "samplefraction (-r) must be a float between 0.0 and 1.0" + ) self.query = query self.samplemethod = samplemethod @@ -43,12 +58,19 @@ def to_command(self, kind, sql_context_variable_name): elif kind == constants.SESSION_KIND_SPARKR: return self._r_command(sql_context_variable_name) else: - raise BadUserDataException(u"Kind '{}' is not supported.".format(kind)) + raise BadUserDataException("Kind '{}' is not supported.".format(kind)) def execute(self, session): - self._spark_events.emit_sql_execution_start_event(session.guid, session.kind, session.id, self.guid, - self.samplemethod, self.maxrows, self.samplefraction) - command_guid = '' + self._spark_events.emit_sql_execution_start_event( + session.guid, + session.kind, + session.id, + self.guid, + self.samplemethod, + self.maxrows, + self.samplefraction, + ) + command_guid = "" try: command = self.to_command(session.kind, session.sql_context_variable_name) command_guid = command.guid @@ -57,66 +79,89 @@ def execute(self, session): raise BadUserDataException(records_text) result = records_to_dataframe(records_text, session.kind, self._coerce) except Exception as e: - self._spark_events.emit_sql_execution_end_event(session.guid, session.kind, session.id, self.guid, - command_guid, False, e.__class__.__name__, str(e)) + self._spark_events.emit_sql_execution_end_event( + 
session.guid, + session.kind, + session.id, + self.guid, + command_guid, + False, + e.__class__.__name__, + str(e), + ) raise else: - self._spark_events.emit_sql_execution_end_event(session.guid, session.kind, session.id, self.guid, - command_guid, True, "", "") + self._spark_events.emit_sql_execution_end_event( + session.guid, + session.kind, + session.id, + self.guid, + command_guid, + True, + "", + "", + ) return result - def _pyspark_command(self, sql_context_variable_name): # use_unicode=False means the result will be UTF-8-encoded bytes, so we # set it to False for Python 2. - command = u'{}.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2))'.format( - sql_context_variable_name, self.query) - if self.samplemethod == u'sample': - command = u'{}.sample(False, {})'.format(command, self.samplefraction) + command = '{}.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2))'.format( + sql_context_variable_name, self.query + ) + if self.samplemethod == "sample": + command = "{}.sample(False, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'{}.take({})'.format(command, self.maxrows) + command = "{}.take({})".format(command, self.maxrows) else: - command = u'{}.collect()'.format(command) + command = "{}.collect()".format(command) print_command = constants.LONG_RANDOM_VARIABLE_NAME - command = u'import sys\nfor {} in {}: print({})'.format( - constants.LONG_RANDOM_VARIABLE_NAME, - command, - print_command) + command = "import sys\nfor {} in {}: print({})".format( + constants.LONG_RANDOM_VARIABLE_NAME, command, print_command + ) return Command(command) def _scala_command(self, sql_context_variable_name): - command = u'{}.sql("""{}""").toJSON'.format(sql_context_variable_name, self.query) - if self.samplemethod == u'sample': - command = u'{}.sample(false, {})'.format(command, self.samplefraction) + command = '{}.sql("""{}""").toJSON'.format( + sql_context_variable_name, self.query + ) + if self.samplemethod == "sample": + command = "{}.sample(false, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'{}.take({})'.format(command, self.maxrows) + command = "{}.take({})".format(command, self.maxrows) else: - command = u'{}.collect'.format(command) - return Command(u'{}.foreach(println)'.format(command)) + command = "{}.collect".format(command) + return Command("{}.foreach(println)".format(command)) def _r_command(self, sql_context_variable_name): - if sql_context_variable_name == 'spark': - command = u'sql("{}")'.format(self.query) + if sql_context_variable_name == "spark": + command = 'sql("{}")'.format(self.query) else: - command = u'sql({}, "{}")'.format(sql_context_variable_name, self.query) - if self.samplemethod == u'sample': - command = u'sample({}, FALSE, {})'.format(command, self.samplefraction) + command = 'sql({}, "{}")'.format(sql_context_variable_name, self.query) + if self.samplemethod == "sample": + command = "sample({}, FALSE, {})".format(command, self.samplefraction) if self.maxrows >= 0: - command = u'take({},{})'.format(command, self.maxrows) + command = "take({},{})".format(command, self.maxrows) else: - command = u'collect({})'.format(command) - command = u'jsonlite:::toJSON({})'.format(command) - command = u'for ({} in ({})) {{cat({})}}'.format(constants.LONG_RANDOM_VARIABLE_NAME, command, constants.LONG_RANDOM_VARIABLE_NAME) + command = "collect({})".format(command) + command = "jsonlite:::toJSON({})".format(command) + command = "for ({} in ({})) {{cat({})}}".format( + 
constants.LONG_RANDOM_VARIABLE_NAME, + command, + constants.LONG_RANDOM_VARIABLE_NAME, + ) return Command(command) # Used only for unit testing def __eq__(self, other): - return self.query == other.query and \ - self.samplemethod == other.samplemethod and \ - self.maxrows == other.maxrows and \ - self.samplefraction == other.samplefraction and \ - self._coerce == other._coerce + return ( + self.query == other.query + and self.samplemethod == other.samplemethod + and self.maxrows == other.maxrows + and self.samplefraction == other.samplefraction + and self._coerce == other._coerce + ) def __ne__(self, other): return not (self == other) diff --git a/sparkmagic/sparkmagic/magics/remotesparkmagics.py b/sparkmagic/sparkmagic/magics/remotesparkmagics.py index 8c6fffcb8..eb3a1b3b7 100644 --- a/sparkmagic/sparkmagic/magics/remotesparkmagics.py +++ b/sparkmagic/sparkmagic/magics/remotesparkmagics.py @@ -12,8 +12,18 @@ from hdijupyterutils.ipywidgetfactory import IpyWidgetFactory import sparkmagic.utils.configuration as conf -from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value, initialize_auth -from sparkmagic.utils.constants import CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, LANG_PYTHON, LANG_R, LANG_SCALA +from sparkmagic.utils.utils import ( + parse_argstring_or_throw, + get_coerce_value, + initialize_auth, +) +from sparkmagic.utils.constants import ( + CONTEXT_NAME_SPARK, + CONTEXT_NAME_SQL, + LANG_PYTHON, + LANG_R, + LANG_SCALA, +) from sparkmagic.controllerwidget.magicscontrollerwidget import MagicsControllerWidget from sparkmagic.livyclientlib.endpoint import Endpoint from sparkmagic.magics.sparkmagicsbase import SparkMagicBase @@ -28,7 +38,9 @@ def __init__(self, shell, data=None, widget=None): self.endpoints = {} if widget is None: - widget = MagicsControllerWidget(self.spark_controller, IpyWidgetFactory(), self.ipython_display) + widget = MagicsControllerWidget( + self.spark_controller, IpyWidgetFactory(), self.ipython_display + ) self.manage_widget = widget @line_magic @@ -38,80 +50,163 @@ def manage_spark(self, line, local_ns=None): return self.manage_widget @magic_arguments() - @argument("-c", "--context", type=str, default=CONTEXT_NAME_SPARK, - help="Context to use: '{}' for spark and '{}' for sql queries. " - "Default is '{}'.".format(CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, CONTEXT_NAME_SPARK)) - @argument("-s", "--session", type=str, default=None, help="The name of the Livy session to use.") - @argument("-o", "--output", type=str, default=None, help="If present, output when using SQL " - "queries will be stored in this variable.") - @argument("-q", "--quiet", type=bool, default=False, nargs="?", const=True, help="Do not display visualizations" - " on SQL queries") - @argument("-m", "--samplemethod", type=str, default=None, help="Sample method for SQL queries: either take or sample") - @argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back " - "from the server for SQL queries") - @argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from SQL queries") + @argument( + "-c", + "--context", + type=str, + default=CONTEXT_NAME_SPARK, + help="Context to use: '{}' for spark and '{}' for sql queries. 
" + "Default is '{}'.".format( + CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, CONTEXT_NAME_SPARK + ), + ) + @argument( + "-s", + "--session", + type=str, + default=None, + help="The name of the Livy session to use.", + ) + @argument( + "-o", + "--output", + type=str, + default=None, + help="If present, output when using SQL " + "queries will be stored in this variable.", + ) + @argument( + "-q", + "--quiet", + type=bool, + default=False, + nargs="?", + const=True, + help="Do not display visualizations" " on SQL queries", + ) + @argument( + "-m", + "--samplemethod", + type=str, + default=None, + help="Sample method for SQL queries: either take or sample", + ) + @argument( + "-n", + "--maxrows", + type=int, + default=None, + help="Maximum number of rows that will be pulled back " + "from the server for SQL queries", + ) + @argument( + "-r", + "--samplefraction", + type=float, + default=None, + help="Sample fraction for sampling from SQL queries", + ) @argument("-u", "--url", type=str, default=None, help="URL for Livy endpoint") - @argument("-a", "--user", dest='user', type=str, default="", help="Username for HTTP access to Livy endpoint") - @argument("-p", "--password", type=str, default="", help="Password for HTTP access to Livy endpoint") - @argument("-t", "--auth", type=str, default=None, help="Auth type for HTTP access to Livy endpoint. [Kerberos, None, Basic]") - @argument("-l", "--language", type=str, default=None, - help="Language for Livy session; one of {}".format(', '.join([LANG_PYTHON, LANG_SCALA, LANG_R]))) + @argument( + "-a", + "--user", + dest="user", + type=str, + default="", + help="Username for HTTP access to Livy endpoint", + ) + @argument( + "-p", + "--password", + type=str, + default="", + help="Password for HTTP access to Livy endpoint", + ) + @argument( + "-t", + "--auth", + type=str, + default=None, + help="Auth type for HTTP access to Livy endpoint. [Kerberos, None, Basic]", + ) + @argument( + "-l", + "--language", + type=str, + default=None, + help="Language for Livy session; one of {}".format( + ", ".join([LANG_PYTHON, LANG_SCALA, LANG_R]) + ), + ) @argument("command", type=str, default=[""], nargs="*", help="Commands to execute.") - @argument("-k", "--skip", type=bool, default=False, nargs="?", const=True, help="Skip adding session if it already exists") + @argument( + "-k", + "--skip", + type=bool, + default=False, + nargs="?", + const=True, + help="Skip adding session if it already exists", + ) @argument("-i", "--id", type=int, default=None, help="Session ID") - @argument("-e", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) " - "of the dataframe or not (pass False)") - + @argument( + "-e", + "--coerce", + type=str, + default=None, + help="Whether to automatically coerce the types (default, pass True if being explicit) " + "of the dataframe or not (pass False)", + ) @needs_local_scope @line_cell_magic @handle_expected_exceptions def spark(self, line, cell="", local_ns=None): """Magic to execute spark remotely. - This magic allows you to create a Livy Scala or Python session against a Livy endpoint. Every session can - be used to execute either Spark code or SparkSQL code by executing against the SQL context in the session. - When the SQL context is used, the result will be a Pandas dataframe of a sample of the results. - - If invoked with no subcommand, the cell will be executed against the specified session. 
- - Subcommands - ----------- - info - Display the available Livy sessions and other configurations for sessions. - add - Add a Livy session given a session name (-s), language (-l), and endpoint credentials. - The -k argument, if present, will skip adding this session if it already exists. - e.g. `%spark add -s test -l python -u https://sparkcluster.net/livy -t Kerberos -a u -p -k` - config - Override the livy session properties sent to Livy on session creation. All session creations will - contain these config settings from then on. - Expected value is a JSON key-value string to be sent as part of the Request Body for the POST /sessions - endpoint in Livy. - e.g. `%%spark config` - `{"driverMemory":"1000M", "executorCores":4}` - run - Run Spark code against a session. - e.g. `%%spark -s testsession` will execute the cell code against the testsession previously created - e.g. `%%spark -s testsession -c sql` will execute the SQL code against the testsession previously created - e.g. `%%spark -s testsession -c sql -o my_var` will execute the SQL code against the testsession - previously created and store the pandas dataframe created in the my_var variable in the - Python environment. - logs - Returns the logs for a given session. - e.g. `%spark logs -s testsession` will return the logs for the testsession previously created - delete - Delete a Livy session. - e.g. `%spark delete -s defaultlivy` - cleanup - Delete all Livy sessions created by the notebook. No arguments required. - e.g. `%spark cleanup` + This magic allows you to create a Livy Scala or Python session against a Livy endpoint. Every session can + be used to execute either Spark code or SparkSQL code by executing against the SQL context in the session. + When the SQL context is used, the result will be a Pandas dataframe of a sample of the results. + + If invoked with no subcommand, the cell will be executed against the specified session. + + Subcommands + ----------- + info + Display the available Livy sessions and other configurations for sessions. + add + Add a Livy session given a session name (-s), language (-l), and endpoint credentials. + The -k argument, if present, will skip adding this session if it already exists. + e.g. `%spark add -s test -l python -u https://sparkcluster.net/livy -t Kerberos -a u -p -k` + config + Override the livy session properties sent to Livy on session creation. All session creations will + contain these config settings from then on. + Expected value is a JSON key-value string to be sent as part of the Request Body for the POST /sessions + endpoint in Livy. + e.g. `%%spark config` + `{"driverMemory":"1000M", "executorCores":4}` + run + Run Spark code against a session. + e.g. `%%spark -s testsession` will execute the cell code against the testsession previously created + e.g. `%%spark -s testsession -c sql` will execute the SQL code against the testsession previously created + e.g. `%%spark -s testsession -c sql -o my_var` will execute the SQL code against the testsession + previously created and store the pandas dataframe created in the my_var variable in the + Python environment. + logs + Returns the logs for a given session. + e.g. `%spark logs -s testsession` will return the logs for the testsession previously created + delete + Delete a Livy session. + e.g. `%spark delete -s defaultlivy` + cleanup + Delete all Livy sessions created by the notebook. No arguments required. + e.g. `%spark cleanup` """ usage = "Please look at usage of %spark by executing `%spark?`." 
user_input = line args = parse_argstring_or_throw(self.spark, user_input) subcommand = args.command[0].lower() - + if args.auth is None: args.auth = conf.get_auth_value(args.user, args.password) else: @@ -121,7 +216,9 @@ def spark(self, line, cell="", local_ns=None): if subcommand == "info": if args.url is not None and args.id is not None: endpoint = Endpoint(args.url, initialize_auth(args)) - info_sessions = self.spark_controller.get_all_sessions_endpoint_info(endpoint) + info_sessions = self.spark_controller.get_all_sessions_endpoint_info( + endpoint + ) self._print_endpoint_info(info_sessions, args.id) else: self._print_local_info() @@ -131,12 +228,14 @@ def spark(self, line, cell="", local_ns=None): # add elif subcommand == "add": if args.url is None: - self.ipython_display.send_error("Need to supply URL argument (e.g. -u https://example.com/livyendpoint)") + self.ipython_display.send_error( + "Need to supply URL argument (e.g. -u https://example.com/livyendpoint)" + ) return name = args.session language = args.language - + endpoint = Endpoint(args.url, initialize_auth(args)) skip = args.skip @@ -149,13 +248,17 @@ def spark(self, line, cell="", local_ns=None): self.spark_controller.delete_session_by_name(args.session) elif args.url is not None: if args.id is None: - self.ipython_display.send_error("Must provide --id or -i option to delete session at endpoint from URL") + self.ipython_display.send_error( + "Must provide --id or -i option to delete session at endpoint from URL" + ) return endpoint = Endpoint(args.url, initialize_auth(args)) session_id = args.id self.spark_controller.delete_session_by_id(endpoint, session_id) else: - self.ipython_display.send_error("Subcommand 'delete' requires a session name or a URL and session ID") + self.ipython_display.send_error( + "Subcommand 'delete' requires a session name or a URL and session ID" + ) # cleanup elif subcommand == "cleanup": if args.url is not None: @@ -170,25 +273,52 @@ def spark(self, line, cell="", local_ns=None): elif len(subcommand) == 0: coerce = get_coerce_value(args.coerce) if args.context == CONTEXT_NAME_SPARK: - return self.execute_spark(cell, args.output, args.samplemethod, - args.maxrows, args.samplefraction, args.session, coerce) + return self.execute_spark( + cell, + args.output, + args.samplemethod, + args.maxrows, + args.samplefraction, + args.session, + coerce, + ) elif args.context == CONTEXT_NAME_SQL: - return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, args.samplefraction, - args.session, args.output, args.quiet, coerce) + return self.execute_sqlquery( + cell, + args.samplemethod, + args.maxrows, + args.samplefraction, + args.session, + args.output, + args.quiet, + coerce, + ) else: - self.ipython_display.send_error("Context '{}' not found".format(args.context)) + self.ipython_display.send_error( + "Context '{}' not found".format(args.context) + ) # error else: - self.ipython_display.send_error("Subcommand '{}' not found. {}".format(subcommand, usage)) + self.ipython_display.send_error( + "Subcommand '{}' not found. 
{}".format(subcommand, usage) + ) def _print_local_info(self): - sessions_info = [" {}".format(i) for i in self.spark_controller.get_manager_sessions_str()] - print("""Info for running Spark: + sessions_info = [ + " {}".format(i) + for i in self.spark_controller.get_manager_sessions_str() + ] + print( + """Info for running Spark: Sessions: {} Session configs: {} -""".format("\n".join(sessions_info), conf.session_configs())) +""".format( + "\n".join(sessions_info), conf.session_configs() + ) + ) + def load_ipython_extension(ip): ip.register_magics(RemoteSparkMagics) diff --git a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py index 5c5f62045..40c6b1802 100644 --- a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py +++ b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py @@ -16,48 +16,62 @@ from sparkmagic.utils.sparklogger import SparkLog from sparkmagic.utils.sparkevents import SparkEvents from sparkmagic.utils.utils import get_sessions_info_html -from sparkmagic.utils.constants import MIMETYPE_TEXT_HTML +from sparkmagic.utils.constants import MIMETYPE_TEXT_HTML from sparkmagic.livyclientlib.sparkcontroller import SparkController from sparkmagic.livyclientlib.sqlquery import SQLQuery from sparkmagic.livyclientlib.command import Command from sparkmagic.livyclientlib.sparkstorecommand import SparkStoreCommand from sparkmagic.livyclientlib.exceptions import SparkStatementException -from sparkmagic.livyclientlib.sendpandasdftosparkcommand import SendPandasDfToSparkCommand +from sparkmagic.livyclientlib.sendpandasdftosparkcommand import ( + SendPandasDfToSparkCommand, +) from sparkmagic.livyclientlib.sendstringtosparkcommand import SendStringToSparkCommand from sparkmagic.livyclientlib.exceptions import BadUserDataException -# How to display different cell content types in IPython -SparkOutputHandler = namedtuple('SparkOutputHandler', ['html', 'text', 'default']) +# How to display different cell content types in IPython +SparkOutputHandler = namedtuple("SparkOutputHandler", ["html", "text", "default"]) @magics_class class SparkMagicBase(Magics): - _STRING_VAR_TYPE = 'str' - _PANDAS_DATAFRAME_VAR_TYPE = 'df' + _STRING_VAR_TYPE = "str" + _PANDAS_DATAFRAME_VAR_TYPE = "df" _ALLOWED_LOCAL_TO_SPARK_TYPES = [_STRING_VAR_TYPE, _PANDAS_DATAFRAME_VAR_TYPE] def __init__(self, shell, data=None, spark_events=None): # You must call the parent constructor super(SparkMagicBase, self).__init__(shell) - self.logger = SparkLog(u"SparkMagics") + self.logger = SparkLog("SparkMagics") self.ipython_display = IpythonDisplay() self.spark_controller = SparkController(self.ipython_display) - self.logger.debug(u'Initialized spark magics.') + self.logger.debug("Initialized spark magics.") if spark_events is None: spark_events = SparkEvents() spark_events.emit_library_loaded_event() - def do_send_to_spark(self, cell, input_variable_name, var_type, output_variable_name, max_rows, session_name): + def do_send_to_spark( + self, + cell, + input_variable_name, + var_type, + output_variable_name, + max_rows, + session_name, + ): try: input_variable_value = self.shell.user_ns[input_variable_name] except KeyError: - raise BadUserDataException(u'Variable named {} not found.'.format(input_variable_name)) + raise BadUserDataException( + "Variable named {} not found.".format(input_variable_name) + ) if input_variable_value is None: - raise BadUserDataException(u'Value of {} is None!'.format(input_variable_name)) + raise BadUserDataException( + "Value of {} is 
None!".format(input_variable_name) + ) if not output_variable_name: output_variable_name = input_variable_name @@ -67,26 +81,52 @@ def do_send_to_spark(self, cell, input_variable_name, var_type, output_variable_ input_variable_type = var_type.lower() if input_variable_type == self._STRING_VAR_TYPE: - command = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) + command = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) elif input_variable_type == self._PANDAS_DATAFRAME_VAR_TYPE: - command = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, max_rows) + command = SendPandasDfToSparkCommand( + input_variable_name, + input_variable_value, + output_variable_name, + max_rows, + ) else: - raise BadUserDataException(u'Invalid or incorrect -t type. Available are: [{}]'.format(u','.join(self._ALLOWED_LOCAL_TO_SPARK_TYPES))) + raise BadUserDataException( + "Invalid or incorrect -t type. Available are: [{}]".format( + ",".join(self._ALLOWED_LOCAL_TO_SPARK_TYPES) + ) + ) (success, result, mime_type) = self.spark_controller.run_command(command, None) if not success: self.ipython_display.send_error(result) else: - self.ipython_display.write(u'Successfully passed \'{}\' as \'{}\' to Spark' - u' kernel'.format(input_variable_name, output_variable_name)) - - def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction, session_name, coerce, output_handler=None): + self.ipython_display.write( + "Successfully passed '{}' as '{}' to Spark" + " kernel".format(input_variable_name, output_variable_name) + ) + + def execute_spark( + self, + cell, + output_var, + samplemethod, + maxrows, + samplefraction, + session_name, + coerce, + output_handler=None, + ): output_handler = output_handler or SparkOutputHandler( - html=self.ipython_display.html, - text=self.ipython_display.write, - default=self.ipython_display.display) - - (success, out, mimetype) = self.spark_controller.run_command(Command(cell), session_name) + html=self.ipython_display.html, + text=self.ipython_display.write, + default=self.ipython_display.display, + ) + + (success, out, mimetype) = self.spark_controller.run_command( + Command(cell), session_name + ) if not success: if conf.shutdown_session_on_spark_statement_errors(): self.spark_controller.cleanup() @@ -101,16 +141,31 @@ def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction, else: output_handler.default(out) if output_var is not None: - spark_store_command = self._spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce) - df = self.spark_controller.run_command(spark_store_command, session_name) + spark_store_command = self._spark_store_command( + output_var, samplemethod, maxrows, samplefraction, coerce + ) + df = self.spark_controller.run_command( + spark_store_command, session_name + ) self.shell.user_ns[output_var] = df @staticmethod def _spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce): - return SparkStoreCommand(output_var, samplemethod, maxrows, samplefraction, coerce=coerce) - - def execute_sqlquery(self, cell, samplemethod, maxrows, samplefraction, - session, output_var, quiet, coerce): + return SparkStoreCommand( + output_var, samplemethod, maxrows, samplefraction, coerce=coerce + ) + + def execute_sqlquery( + self, + cell, + samplemethod, + maxrows, + samplefraction, + session, + output_var, + quiet, + coerce, + ): sqlquery = self._sqlquery(cell, samplemethod, 
maxrows, samplefraction, coerce) df = self.spark_controller.run_sqlquery(sqlquery, session) if output_var is not None: @@ -130,4 +185,4 @@ def _print_endpoint_info(self, info_sessions, current_session_id): html = get_sessions_info_html(info_sessions, current_session_id) self.ipython_display.html(html) else: - self.ipython_display.html(u'No active sessions.') + self.ipython_display.html("No active sessions.") diff --git a/sparkmagic/sparkmagic/serverextension/handlers.py b/sparkmagic/sparkmagic/serverextension/handlers.py index c94a258ca..8be29de15 100644 --- a/sparkmagic/sparkmagic/serverextension/handlers.py +++ b/sparkmagic/sparkmagic/serverextension/handlers.py @@ -19,7 +19,7 @@ class ReconnectHandler(IPythonHandler): @web.authenticated @gen.coroutine def post(self): - self.logger = SparkLog(u"ReconnectHandler") + self.logger = SparkLog("ReconnectHandler") spark_events = self._get_spark_events() @@ -35,13 +35,13 @@ def post(self): endpoint = None try: - path = self._get_argument_or_raise(data, 'path') - username = self._get_argument_or_raise(data, 'username') - password = self._get_argument_or_raise(data, 'password') - endpoint = self._get_argument_or_raise(data, 'endpoint') - auth = self._get_argument_if_exists(data, 'auth') + path = self._get_argument_or_raise(data, "path") + username = self._get_argument_or_raise(data, "username") + password = self._get_argument_or_raise(data, "password") + endpoint = self._get_argument_or_raise(data, "endpoint") + auth = self._get_argument_if_exists(data, "auth") if auth is None: - if username == '' and password == '': + if username == "" and password == "": auth = constants.NO_AUTH else: auth = constants.AUTH_BASIC @@ -59,7 +59,13 @@ def post(self): # Execute code client = kernel_manager.client() - code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, endpoint, username, password, auth) + code = "%{} -s {} -u {} -p {} -t {}".format( + KernelMagics._do_not_call_change_endpoint.__name__, + endpoint, + username, + password, + auth, + ) response_id = client.execute(code, silent=False, store_history=False) msg = client.get_shell_msg(response_id) @@ -69,16 +75,20 @@ def post(self): if successful_message: status_code = 200 else: - self.logger.error(u"Code to reconnect errored out: {}".format(error)) + self.logger.error("Code to reconnect errored out: {}".format(error)) status_code = 500 # Post execution info self.set_status(status_code) - self.finish(json.dumps(dict(success=successful_message, error=error), sort_keys=True)) - spark_events.emit_cluster_change_event(endpoint, status_code, successful_message, error) + self.finish( + json.dumps(dict(success=successful_message, error=error), sort_keys=True) + ) + spark_events.emit_cluster_change_event( + endpoint, status_code, successful_message, error + ) def _get_kernel_name(self, data): - kernel_name = self._get_argument_if_exists(data, 'kernelname') + kernel_name = self._get_argument_if_exists(data, "kernelname") self.logger.debug("Kernel name is {}".format(kernel_name)) if kernel_name is None: kernel_name = conf.server_extension_default_kernel_name() @@ -100,21 +110,25 @@ def _get_kernel_manager(self, path, kernel_name): kernel_id = None for session in sessions: - if session['notebook']['path'] == path: - session_id = session['id'] - kernel_id = session['kernel']['id'] - existing_kernel_name = session['kernel']['name'] + if session["notebook"]["path"] == path: + session_id = session["id"] + kernel_id = session["kernel"]["id"] + existing_kernel_name = 
session["kernel"]["name"] break if kernel_id is None: - self.logger.debug(u"Kernel not found. Starting a new kernel.") + self.logger.debug("Kernel not found. Starting a new kernel.") k_m = yield self._get_kernel_manager_new_session(path, kernel_name) elif existing_kernel_name != kernel_name: - self.logger.debug(u"Existing kernel name '{}' does not match requested '{}'. Starting a new kernel.".format(existing_kernel_name, kernel_name)) + self.logger.debug( + "Existing kernel name '{}' does not match requested '{}'. Starting a new kernel.".format( + existing_kernel_name, kernel_name + ) + ) self._delete_session(session_id) k_m = yield self._get_kernel_manager_new_session(path, kernel_name) else: - self.logger.debug(u"Kernel found. Restarting kernel.") + self.logger.debug("Kernel found. Restarting kernel.") k_m = self.kernel_manager.get_kernel(kernel_id) k_m.restart_kernel() @@ -122,7 +136,9 @@ def _get_kernel_manager(self, path, kernel_name): @gen.coroutine def _get_kernel_manager_new_session(self, path, kernel_name): - model_future = self.session_manager.create_session(kernel_name=kernel_name, path=path, type="notebook") + model_future = self.session_manager.create_session( + kernel_name=kernel_name, path=path, type="notebook" + ) model = yield model_future kernel_id = model["kernel"]["id"] self.logger.debug("Kernel created with id {}".format(str(kernel_id))) @@ -133,18 +149,18 @@ def _delete_session(self, session_id): self.session_manager.delete_session(session_id) def _msg_status(self, msg): - return msg['content']['status'] + return msg["content"]["status"] def _msg_successful(self, msg): - return self._msg_status(msg) == 'ok' + return self._msg_status(msg) == "ok" def _msg_error(self, msg): - if self._msg_status(msg) != 'error': + if self._msg_status(msg) != "error": return None - return u'{}:\n{}'.format(msg['content']['ename'], msg['content']['evalue']) + return "{}:\n{}".format(msg["content"]["ename"], msg["content"]["evalue"]) def _get_spark_events(self): - spark_events = getattr(self, 'spark_events', None) + spark_events = getattr(self, "spark_events", None) if spark_events is None: return SparkEvents() return spark_events @@ -154,10 +170,10 @@ def load_jupyter_server_extension(nb_app): nb_app.log.info("sparkmagic extension enabled!") web_app = nb_app.web_app - base_url = web_app.settings['base_url'] - host_pattern = '.*$' + base_url = web_app.settings["base_url"] + host_pattern = ".*$" - route_pattern_reconnect = url_path_join(base_url, '/reconnectsparkmagic') + route_pattern_reconnect = url_path_join(base_url, "/reconnectsparkmagic") handlers = [(route_pattern_reconnect, ReconnectHandler)] web_app.add_handlers(host_pattern, handlers) diff --git a/sparkmagic/sparkmagic/tests/test_command.py b/sparkmagic/sparkmagic/tests/test_command.py index 8e51c535d..5aae1a9b6 100644 --- a/sparkmagic/sparkmagic/tests/test_command.py +++ b/sparkmagic/sparkmagic/tests/test_command.py @@ -8,8 +8,13 @@ import sparkmagic.livyclientlib.exceptions import sparkmagic.utils.configuration as conf -from sparkmagic.utils.constants import SESSION_KIND_SPARK, MIMETYPE_IMAGE_PNG, MIMETYPE_TEXT_HTML, \ - MIMETYPE_TEXT_PLAIN, COMMAND_INTERRUPTED_MSG +from sparkmagic.utils.constants import ( + SESSION_KIND_SPARK, + MIMETYPE_IMAGE_PNG, + MIMETYPE_TEXT_HTML, + MIMETYPE_TEXT_PLAIN, + COMMAND_INTERRUPTED_MSG, +) from sparkmagic.livyclientlib.command import Command from sparkmagic.livyclientlib.livysession import LivySession from sparkmagic.livyclientlib.exceptions import SparkStatementCancelledException @@ -28,15 
+33,21 @@ def _setup(): conf.override_all({}) -def _create_session(kind=SESSION_KIND_SPARK, session_id=-1, - http_client=None, spark_events=None): +def _create_session( + kind=SESSION_KIND_SPARK, session_id=-1, http_client=None, spark_events=None +): if http_client is None: http_client = MagicMock() if spark_events is None: spark_events = MagicMock() ipython_display = MagicMock() - session = LivySession(http_client, {"kind": kind, "heartbeatTimeoutInSecond": 60}, - ipython_display, session_id, spark_events) + session = LivySession( + http_client, + {"kind": kind, "heartbeatTimeoutInSecond": 60}, + ipython_display, + session_id, + spark_events, + ) return session @@ -69,14 +80,23 @@ def test_execute(): assert result[0] assert_equals(tls.TestLivySession.pi_result, result[1]) assert_equals(MIMETYPE_TEXT_PLAIN, result[2]) - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - 0, True, "", "") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid, 0, True, "", "" + ) # Now try with PNG result: - http_client.get_statement.return_value = {"id":0,"state":"available","output":{"status":"ok", "execution_count":0,"data":{"text/plain":"", "image/png": b64encode(b"hello")}}} + http_client.get_statement.return_value = { + "id": 0, + "state": "available", + "output": { + "status": "ok", + "execution_count": 0, + "data": {"text/plain": "", "image/png": b64encode(b"hello")}, + }, + } result = command.execute(session) assert result[0] assert isinstance(result[1], Image) @@ -84,10 +104,18 @@ def test_execute(): assert_equals(MIMETYPE_IMAGE_PNG, result[2]) # Now try with HTML result: - http_client.get_statement.return_value = {"id":0,"state":"available","output":{"status":"ok", "execution_count":0,"data":{"text/html":"
out"}}}
+    http_client.get_statement.return_value = {
+        "id": 0,
+        "state": "available",
+        "output": {
+            "status": "ok",
+            "execution_count": 0,
+            "data": {"text/html": "out"},
+        },
+    }
     result = command.execute(session)
     assert result[0]
-    assert_equals(u"out", result[1])
+    assert_equals("out
", result[1]) assert_equals(MIMETYPE_TEXT_HTML, result[2]) @@ -99,7 +127,12 @@ def test_execute_waiting(): http_client.post_session.return_value = tls.TestLivySession.session_create_json http_client.post_statement.return_value = tls.TestLivySession.post_statement_json http_client.get_session.return_value = tls.TestLivySession.ready_sessions_json - http_client.get_statement.side_effect = [tls.TestLivySession.waiting_statement_json, tls.TestLivySession.waiting_statement_json, tls.TestLivySession.ready_statement_json, tls.TestLivySession.ready_statement_json] + http_client.get_statement.side_effect = [ + tls.TestLivySession.waiting_statement_json, + tls.TestLivySession.waiting_statement_json, + tls.TestLivySession.ready_statement_json, + tls.TestLivySession.ready_statement_json, + ] session = _create_session(kind=kind, http_client=http_client) session.start() command = Command("command", spark_events=spark_events) @@ -111,11 +144,12 @@ def test_execute_waiting(): assert result[0] assert_equals(tls.TestLivySession.pi_result, result[1]) assert_equals(MIMETYPE_TEXT_PLAIN, result[2]) - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - 0, True, "", "") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid, 0, True, "", "" + ) @with_setup(_setup) @@ -126,7 +160,9 @@ def test_execute_null_ouput(): http_client.post_session.return_value = tls.TestLivySession.session_create_json http_client.post_statement.return_value = tls.TestLivySession.post_statement_json http_client.get_session.return_value = tls.TestLivySession.ready_sessions_json - http_client.get_statement.return_value = tls.TestLivySession.ready_statement_null_output_json + http_client.get_statement.return_value = ( + tls.TestLivySession.ready_statement_null_output_json + ) session = _create_session(kind=kind, http_client=http_client) session.start() command = Command("command", spark_events=spark_events) @@ -136,13 +172,14 @@ def test_execute_null_ouput(): http_client.post_statement.assert_called_with(0, {"code": command.code}) http_client.get_statement.assert_called_with(0, 0) assert result[0] - assert_equals(u"", result[1]) + assert_equals("", result[1]) assert_equals(MIMETYPE_TEXT_PLAIN, result[2]) - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - 0, True, "", "") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid, 0, True, "", "" + ) @with_setup(_setup) @@ -163,11 +200,19 @@ def test_execute_failure_wait_for_session_emits_event(): result = command.execute(session) assert False except ValueError as e: - spark_events.emit_statement_execution_start_event.assert_called_with(session.guid, session.kind, - session.id, command.guid) - 
spark_events.emit_statement_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - -1, False, "ValueError", "yo") + spark_events.emit_statement_execution_start_event.assert_called_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event.assert_called_once_with( + session.guid, + session.kind, + session.id, + command.guid, + -1, + False, + "ValueError", + "yo", + ) assert_equals(e, session.wait_for_idle.side_effect) @@ -183,17 +228,24 @@ def test_execute_failure_post_statement_emits_event(): session.wait_for_idle = MagicMock() command = Command("command", spark_events=spark_events) - http_client.post_statement.side_effect = KeyError('Something bad happened here') + http_client.post_statement.side_effect = KeyError("Something bad happened here") try: result = command.execute(session) assert False except KeyError as e: - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - -1, False, "KeyError", - "Something bad happened here") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event._assert_called_once_with( + session.guid, + session.kind, + session.id, + command.guid, + -1, + False, + "KeyError", + "Something bad happened here", + ) assert_equals(e, http_client.post_statement.side_effect) @@ -209,18 +261,25 @@ def test_execute_failure_get_statement_output_emits_event(): session.start() session.wait_for_idle = MagicMock() command = Command("command", spark_events=spark_events) - command._get_statement_output = MagicMock(side_effect=AttributeError('OHHHH')) + command._get_statement_output = MagicMock(side_effect=AttributeError("OHHHH")) try: result = command.execute(session) assert False except AttributeError as e: - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - -1, False, "AttributeError", - "OHHHH") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event._assert_called_once_with( + session.guid, + session.kind, + session.id, + command.guid, + -1, + False, + "AttributeError", + "OHHHH", + ) assert_equals(e, command._get_statement_output.side_effect) @@ -245,12 +304,19 @@ def test_execute_interrupted(): result = command.execute(session) assert False except KeyboardInterrupt as e: - spark_events.emit_statement_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, command.guid) - spark_events.emit_statement_execution_end_event._assert_called_once_with(session.guid, session.kind, - session.id, command.guid, - -1, False, "KeyboardInterrupt", - "") + spark_events.emit_statement_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, command.guid + ) + spark_events.emit_statement_execution_end_event._assert_called_once_with( + session.guid, + session.kind, + session.id, + command.guid, + -1, + False, + "KeyboardInterrupt", + "", + ) assert isinstance(e, 
SparkStatementCancelledException) assert_equals(str(e), COMMAND_INTERRUPTED_MSG) @@ -263,9 +329,11 @@ def test_execute_interrupted(): assert not stderr.getvalue() with _capture_stderr() as stderr: - mock_ipython._showtraceback(SparkStatementCancelledException, COMMAND_INTERRUPTED_MSG, MagicMock()) - mock_show_tb.assert_called_once() # still once - assert_equals(stderr.getvalue().strip(), COMMAND_INTERRUPTED_MSG) + mock_ipython._showtraceback( + SparkStatementCancelledException, COMMAND_INTERRUPTED_MSG, MagicMock() + ) + mock_show_tb.assert_called_once() # still once + assert_equals(stderr.getvalue().strip(), COMMAND_INTERRUPTED_MSG) except: assert False diff --git a/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py b/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py index 680b3303c..383a81318 100644 --- a/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py +++ b/sparkmagic/sparkmagic/tests/test_configurableretrypolicy.py @@ -14,17 +14,17 @@ def test_with_empty_list(): assert_equals(5, policy.seconds_to_sleep(4)) assert_equals(5, policy.seconds_to_sleep(5)) assert_equals(5, policy.seconds_to_sleep(6)) - + # Check based on retry count assert_equals(True, policy.should_retry(500, False, 0)) assert_equals(True, policy.should_retry(500, False, 4)) assert_equals(True, policy.should_retry(500, False, 5)) assert_equals(False, policy.should_retry(500, False, 6)) - + # Check based on status code assert_equals(False, policy.should_retry(201, False, 0)) assert_equals(False, policy.should_retry(201, False, 6)) - + # Check based on error assert_equals(True, policy.should_retry(201, True, 0)) assert_equals(True, policy.should_retry(201, True, 6)) @@ -45,11 +45,11 @@ def test_with_one_element_list(): assert_equals(True, policy.should_retry(500, False, 4)) assert_equals(True, policy.should_retry(500, False, 5)) assert_equals(False, policy.should_retry(500, False, 6)) - + # Check based on status code assert_equals(False, policy.should_retry(201, False, 0)) assert_equals(False, policy.should_retry(201, False, 6)) - + # Check based on error assert_equals(True, policy.should_retry(201, True, 0)) assert_equals(True, policy.should_retry(201, True, 6)) @@ -76,11 +76,11 @@ def test_with_default_values(): assert_equals(True, policy.should_retry(500, False, 7)) assert_equals(True, policy.should_retry(500, False, 8)) assert_equals(False, policy.should_retry(500, False, 9)) - + # Check based on status code assert_equals(False, policy.should_retry(201, False, 0)) assert_equals(False, policy.should_retry(201, False, 9)) - + # Check based on error assert_equals(True, policy.should_retry(201, True, 0)) assert_equals(True, policy.should_retry(201, True, 9)) @@ -89,7 +89,7 @@ def test_with_default_values(): def test_with_negative_values(): times = [0.1, -1] max_retries = 5 - + try: policy = ConfigurableRetryPolicy(times, max_retries) assert False diff --git a/sparkmagic/sparkmagic/tests/test_configuration.py b/sparkmagic/sparkmagic/tests/test_configuration.py index 52ca0e2f5..aff41a3da 100644 --- a/sparkmagic/sparkmagic/tests/test_configuration.py +++ b/sparkmagic/sparkmagic/tests/test_configuration.py @@ -9,65 +9,104 @@ def _setup(): conf.override_all({}) - + @with_setup(_setup) def test_configuration_override_base64_password(): - kpc = { 'username': 'U', 'password': 'P', 'base64_password': 'cGFzc3dvcmQ=', 'url': 'L', "auth": AUTH_BASIC } - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = { + "username": "U", + "password": "P", + "base64_password": "cGFzc3dvcmQ=", + "url": 
"L", + "auth": AUTH_BASIC, + } + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1) - assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc, - conf.livy_session_startup_timeout_seconds.__name__: 1 }) + assert_equals( + conf.d, + { + conf.kernel_python_credentials.__name__: kpc, + conf.livy_session_startup_timeout_seconds.__name__: 1, + }, + ) assert_equals(conf.livy_session_startup_timeout_seconds(), 1) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'password', 'url': 'L', 'auth': AUTH_BASIC }) + assert_equals( + conf.base64_kernel_python_credentials(), + {"username": "U", "password": "password", "url": "L", "auth": AUTH_BASIC}, + ) @with_setup(_setup) def test_configuration_auth_missing_basic_auth(): - kpc = { 'username': 'U', 'password': 'P', 'url': 'L'} - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = {"username": "U", "password": "P", "url": "L"} + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': AUTH_BASIC }) + assert_equals( + conf.base64_kernel_python_credentials(), + {"username": "U", "password": "P", "url": "L", "auth": AUTH_BASIC}, + ) @with_setup(_setup) def test_configuration_auth_missing_no_auth(): - kpc = { 'username': '', 'password': '', 'url': 'L'} - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = {"username": "", "password": "", "url": "L"} + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': '', 'password': '', 'url': 'L', 'auth': NO_AUTH }) + assert_equals( + conf.base64_kernel_python_credentials(), + {"username": "", "password": "", "url": "L", "auth": NO_AUTH}, + ) @with_setup(_setup) def test_configuration_override_fallback_to_password(): - kpc = { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': NO_AUTH } - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = {"username": "U", "password": "P", "url": "L", "auth": NO_AUTH} + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1) - assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc, - conf.livy_session_startup_timeout_seconds.__name__: 1 }) + assert_equals( + conf.d, + { + conf.kernel_python_credentials.__name__: kpc, + conf.livy_session_startup_timeout_seconds.__name__: 1, + }, + ) assert_equals(conf.livy_session_startup_timeout_seconds(), 1) assert_equals(conf.base64_kernel_python_credentials(), kpc) @with_setup(_setup) def test_configuration_override_work_with_empty_password(): - kpc = { 'username': 'U', 'base64_password': '', 'password': '', 'url': '', 'auth': AUTH_BASIC } - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = { + "username": "U", + "base64_password": "", + "password": "", + "url": "", + "auth": AUTH_BASIC, + } + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1) - assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc, - conf.livy_session_startup_timeout_seconds.__name__: 1 }) + assert_equals( + conf.d, + { + conf.kernel_python_credentials.__name__: kpc, + 
conf.livy_session_startup_timeout_seconds.__name__: 1, + }, + ) assert_equals(conf.livy_session_startup_timeout_seconds(), 1) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': '', 'url': '', 'auth': AUTH_BASIC }) + assert_equals( + conf.base64_kernel_python_credentials(), + {"username": "U", "password": "", "url": "", "auth": AUTH_BASIC}, + ) @raises(BadUserConfigurationException) @with_setup(_setup) def test_configuration_raise_error_for_bad_base64_password(): - kpc = { 'username': 'U', 'base64_password': 'P', 'url': 'L' } - overrides = { conf.kernel_python_credentials.__name__: kpc } + kpc = {"username": "U", "base64_password": "P", "url": "L"} + overrides = {conf.kernel_python_credentials.__name__: kpc} conf.override_all(overrides) conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1) conf.base64_kernel_python_credentials() @@ -75,6 +114,15 @@ def test_configuration_raise_error_for_bad_base64_password(): @with_setup(_setup) def test_share_config_between_pyspark_and_pyspark3(): - kpc = { 'username': 'U', 'password': 'P', 'base64_password': 'cGFzc3dvcmQ=', 'url': 'L', 'auth': AUTH_BASIC } - overrides = { conf.kernel_python_credentials.__name__: kpc } - assert_equals(conf.base64_kernel_python3_credentials(), conf.base64_kernel_python_credentials()) + kpc = { + "username": "U", + "password": "P", + "base64_password": "cGFzc3dvcmQ=", + "url": "L", + "auth": AUTH_BASIC, + } + overrides = {conf.kernel_python_credentials.__name__: kpc} + assert_equals( + conf.base64_kernel_python3_credentials(), + conf.base64_kernel_python_credentials(), + ) diff --git a/sparkmagic/sparkmagic/tests/test_endpoint.py b/sparkmagic/sparkmagic/tests/test_endpoint.py index c3491cf84..1a2259d64 100644 --- a/sparkmagic/sparkmagic/tests/test_endpoint.py +++ b/sparkmagic/sparkmagic/tests/test_endpoint.py @@ -5,20 +5,30 @@ from sparkmagic.auth.basic import Basic from sparkmagic.auth.kerberos import Kerberos + def test_equality(): basic_auth1 = Basic() basic_auth2 = Basic() kerberos_auth1 = Kerberos() kerberos_auth2 = Kerberos() - assert_equals(Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2)) - assert_equals(Endpoint("http://url.com", kerberos_auth1), Endpoint("http://url.com", kerberos_auth2)) + assert_equals( + Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2) + ) + assert_equals( + Endpoint("http://url.com", kerberos_auth1), + Endpoint("http://url.com", kerberos_auth2), + ) + def test_inequality(): basic_auth1 = Basic() basic_auth2 = Basic() - basic_auth1.username = 'user' - basic_auth2.username = 'different_user' - assert_not_equal(Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2)) + basic_auth1.username = "user" + basic_auth2.username = "different_user" + assert_not_equal( + Endpoint("http://url.com", basic_auth1), Endpoint("http://url.com", basic_auth2) + ) + def test_invalid_url(): basic_auth = Basic() diff --git a/sparkmagic/sparkmagic/tests/test_exceptions.py b/sparkmagic/sparkmagic/tests/test_exceptions.py index a0ce92300..db2e91334 100644 --- a/sparkmagic/sparkmagic/tests/test_exceptions.py +++ b/sparkmagic/sparkmagic/tests/test_exceptions.py @@ -21,7 +21,7 @@ def _setup(): @with_setup(_setup) def test_handle_expected_exceptions(): mock_method = MagicMock() - mock_method.__name__ = 'MockMethod' + mock_method.__name__ = "MockMethod" decorated = handle_expected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) @@ -33,51 +33,48 @@ def 
test_handle_expected_exceptions(): @with_setup(_setup) def test_handle_expected_exceptions_handle(): - conf.override_all({ - 'all_errors_are_fatal': False - }) - mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('ridiculous')) - mock_method.__name__ = 'MockMethod2' + conf.override_all({"all_errors_are_fatal": False}) + mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("ridiculous")) + mock_method.__name__ = "MockMethod2" decorated = handle_expected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) - result = decorated(self, 1, kwarg='foo') + result = decorated(self, 1, kwarg="foo") assert_is(result, None) assert_equals(ipython_display.send_error.call_count, 1) - mock_method.assert_called_once_with(self, 1, kwarg='foo') + mock_method.assert_called_once_with(self, 1, kwarg="foo") @raises(ValueError) @with_setup(_setup) def test_handle_expected_exceptions_throw(): - mock_method = MagicMock(side_effect=ValueError('HALP')) - mock_method.__name__ = 'mock_meth' + mock_method = MagicMock(side_effect=ValueError("HALP")) + mock_method.__name__ = "mock_meth" decorated = handle_expected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) - result = decorated(self, 1, kwarg='foo') + result = decorated(self, 1, kwarg="foo") @raises(LivyUnexpectedStatusException) @with_setup(_setup) def test_handle_expected_exceptions_throws_if_all_errors_fatal(): - conf.override_all({ - 'all_errors_are_fatal': True - }) - mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('Oh no!')) - mock_method.__name__ = 'mock_meth' + conf.override_all({"all_errors_are_fatal": True}) + mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("Oh no!")) + mock_method.__name__ = "mock_meth" decorated = handle_expected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) - result = decorated(self, 1, kwarg='foo') + result = decorated(self, 1, kwarg="foo") # test wrap with unexpected to true + @with_setup(_setup) def test_wrap_unexpected_exceptions(): mock_method = MagicMock() - mock_method.__name__ = 'tos' + mock_method.__name__ = "tos" decorated = wrap_unexpected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) @@ -89,27 +86,26 @@ def test_wrap_unexpected_exceptions(): @with_setup(_setup) def test_wrap_unexpected_exceptions_handle(): - mock_method = MagicMock(side_effect=ValueError('~~~~~~')) - mock_method.__name__ = 'tos' + mock_method = MagicMock(side_effect=ValueError("~~~~~~")) + mock_method.__name__ = "tos" decorated = wrap_unexpected_exceptions(mock_method) assert_equals(decorated.__name__, mock_method.__name__) - result = decorated(self, 'FOOBAR', FOOBAR='FOOBAR') + result = decorated(self, "FOOBAR", FOOBAR="FOOBAR") assert_is(result, None) assert_equals(ipython_display.send_error.call_count, 1) - mock_method.assert_called_once_with(self, 'FOOBAR', FOOBAR='FOOBAR') + mock_method.assert_called_once_with(self, "FOOBAR", FOOBAR="FOOBAR") + # test wrap with unexpected to true # test wrap with all to true @raises(ValueError) @with_setup(_setup) def test_wrap_unexpected_exceptions_throws_if_all_errors_fatal(): - conf.override_all({ - 'all_errors_are_fatal': True - }) - mock_method = MagicMock(side_effect=ValueError('~~~~~~')) - mock_method.__name__ = 'tos' + conf.override_all({"all_errors_are_fatal": True}) + mock_method = MagicMock(side_effect=ValueError("~~~~~~")) + mock_method.__name__ = "tos" decorated = wrap_unexpected_exceptions(mock_method) 
assert_equals(decorated.__name__, mock_method.__name__) - result = decorated(self, 'FOOBAR', FOOBAR='FOOBAR') + result = decorated(self, "FOOBAR", FOOBAR="FOOBAR") diff --git a/sparkmagic/sparkmagic/tests/test_handlers.py b/sparkmagic/sparkmagic/tests/test_handlers.py index 4d01734fd..fb441d88c 100644 --- a/sparkmagic/sparkmagic/tests/test_handlers.py +++ b/sparkmagic/sparkmagic/tests/test_handlers.py @@ -23,24 +23,33 @@ class TestSparkMagicHandler(AsyncTestCase): client = None session_list = None spark_events = None - path = 'some_path.ipynb' - kernel_id = '1' - kernel_name = 'pysparkkernel' - session_id = '1' - username = 'username' - password = 'password' - endpoint = 'http://endpoint.com' + path = "some_path.ipynb" + kernel_id = "1" + kernel_name = "pysparkkernel" + session_id = "1" + username = "username" + password = "password" + endpoint = "http://endpoint.com" auth = constants.AUTH_BASIC - response_id = '0' - good_msg = dict(content=dict(status='ok')) - bad_msg = dict(content=dict(status='error', ename='SyntaxError', evalue='oh no!')) + response_id = "0" + good_msg = dict(content=dict(status="ok")) + bad_msg = dict(content=dict(status="error", ename="SyntaxError", evalue="oh no!")) request = None def create_session_dict(self, path, kernel_id): - return dict(notebook=dict(path=path), kernel=dict(id=kernel_id, name=self.kernel_name), id=self.session_id) + return dict( + notebook=dict(path=path), + kernel=dict(id=kernel_id, name=self.kernel_name), + id=self.session_id, + ) def get_argument(self, key): - return dict(username=self.username, password=self.password, endpoint=self.endpoint, path=self.path)[key] + return dict( + username=self.username, + password=self.password, + endpoint=self.endpoint, + path=self.path, + )[key] def setUp(self): # Mock kernel manager @@ -50,20 +59,32 @@ def setUp(self): self.individual_kernel_manager = MagicMock() self.individual_kernel_manager.client = MagicMock(return_value=self.client) self.kernel_manager = MagicMock() - self.kernel_manager.get_kernel = MagicMock(return_value=self.individual_kernel_manager) + self.kernel_manager.get_kernel = MagicMock( + return_value=self.individual_kernel_manager + ) # Mock session manager self.session_list = [self.create_session_dict(self.path, self.kernel_id)] self.session_manager = MagicMock() self.session_manager.list_sessions = MagicMock(return_value=self.session_list) - self.session_manager.create_session = MagicMock(return_value=self.create_session_dict(self.path, self.kernel_id)) + self.session_manager.create_session = MagicMock( + return_value=self.create_session_dict(self.path, self.kernel_id) + ) # Mock spark events self.spark_events = MagicMock() # Mock request self.request = MagicMock() - self.request.body = json.dumps({"path": self.path, "username": self.username, "password": self.password, "endpoint": self.endpoint, "auth": self.auth}) + self.request.body = json.dumps( + { + "path": self.path, + "username": self.username, + "password": self.password, + "endpoint": self.endpoint, + "auth": self.auth, + } + ) # Create mocked reconnect_handler ReconnectHandler.__bases__ = (SimpleObject,) @@ -73,15 +94,15 @@ def setUp(self): self.reconnect_handler.kernel_manager = self.kernel_manager self.reconnect_handler.set_status = MagicMock() self.reconnect_handler.finish = MagicMock() - self.reconnect_handler.current_user = 'alex' + self.reconnect_handler.current_user = "alex" self.reconnect_handler.request = self.request self.reconnect_handler.logger = MagicMock() super(TestSparkMagicHandler, self).setUp() def 
test_msg_status(self): - assert_equals(self.reconnect_handler._msg_status(self.good_msg), 'ok') - assert_equals(self.reconnect_handler._msg_status(self.bad_msg), 'error') + assert_equals(self.reconnect_handler._msg_status(self.good_msg), "ok") + assert_equals(self.reconnect_handler._msg_status(self.bad_msg), "error") def test_msg_successful(self): assert_equals(self.reconnect_handler._msg_successful(self.good_msg), True) @@ -89,7 +110,10 @@ def test_msg_successful(self): def test_msg_error(self): assert_equals(self.reconnect_handler._msg_error(self.good_msg), None) - assert_equals(self.reconnect_handler._msg_error(self.bad_msg), u'{}:\n{}'.format('SyntaxError', 'oh no!')) + assert_equals( + self.reconnect_handler._msg_error(self.bad_msg), + "{}:\n{}".format("SyntaxError", "oh no!"), + ) @gen_test def test_post_no_json(self): @@ -101,7 +125,9 @@ def test_post_no_json(self): msg = "Invalid JSON in request body." self.reconnect_handler.set_status.assert_called_once_with(400) self.reconnect_handler.finish.assert_called_once_with(msg) - self.spark_events.emit_cluster_change_event.assert_called_once_with(None, 400, False, msg) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + None, 400, False, msg + ) @gen_test def test_post_no_key(self): @@ -110,15 +136,24 @@ def test_post_no_key(self): res = yield self.reconnect_handler.post() assert_equals(res, None) - msg = 'HTTP 400: Bad Request (Missing argument path)' + msg = "HTTP 400: Bad Request (Missing argument path)" self.reconnect_handler.set_status.assert_called_once_with(400) self.reconnect_handler.finish.assert_called_once_with(msg) - self.spark_events.emit_cluster_change_event.assert_called_once_with(None, 400, False, msg) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + None, 400, False, msg + ) - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager') + @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager") @gen_test def test_post_existing_kernel_with_auth_missing_no_auth(self, _get_kernel_manager): - self.request.body = json.dumps({ "path": self.path, "username": '', "password": '', "endpoint": self.endpoint }) + self.request.body = json.dumps( + { + "path": self.path, + "username": "", + "password": "", + "endpoint": self.endpoint, + } + ) kernel_manager_future = Future() kernel_manager_future.set_result(self.individual_kernel_manager) _get_kernel_manager.return_value = kernel_manager_future @@ -126,16 +161,37 @@ def test_post_existing_kernel_with_auth_missing_no_auth(self, _get_kernel_manage res = yield self.reconnect_handler.post() assert_equals(res, None) - code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, '', '', constants.NO_AUTH) - self.client.execute.assert_called_once_with(code, silent=False, store_history=False) + code = "%{} -s {} -u {} -p {} -t {}".format( + KernelMagics._do_not_call_change_endpoint.__name__, + self.endpoint, + "", + "", + constants.NO_AUTH, + ) + self.client.execute.assert_called_once_with( + code, silent=False, store_history=False + ) self.reconnect_handler.set_status.assert_called_once_with(200) - self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}') - self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None) - - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager') + self.reconnect_handler.finish.assert_called_once_with( + '{"error": null, 
"success": true}' + ) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + self.endpoint, 200, True, None + ) + + @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager") @gen_test - def test_post_existing_kernel_with_auth_missing_basic_auth(self, _get_kernel_manager): - self.request.body = json.dumps({ "path": self.path, "username": self.username, "password": self.password, "endpoint": self.endpoint}) + def test_post_existing_kernel_with_auth_missing_basic_auth( + self, _get_kernel_manager + ): + self.request.body = json.dumps( + { + "path": self.path, + "username": self.username, + "password": self.password, + "endpoint": self.endpoint, + } + ) kernel_manager_future = Future() kernel_manager_future.set_result(self.individual_kernel_manager) _get_kernel_manager.return_value = kernel_manager_future @@ -143,13 +199,25 @@ def test_post_existing_kernel_with_auth_missing_basic_auth(self, _get_kernel_man res = yield self.reconnect_handler.post() assert_equals(res, None) - code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, constants.AUTH_BASIC) - self.client.execute.assert_called_once_with(code, silent=False, store_history=False) + code = "%{} -s {} -u {} -p {} -t {}".format( + KernelMagics._do_not_call_change_endpoint.__name__, + self.endpoint, + self.username, + self.password, + constants.AUTH_BASIC, + ) + self.client.execute.assert_called_once_with( + code, silent=False, store_history=False + ) self.reconnect_handler.set_status.assert_called_once_with(200) - self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}') - self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None) - - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager') + self.reconnect_handler.finish.assert_called_once_with( + '{"error": null, "success": true}' + ) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + self.endpoint, 200, True, None + ) + + @patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager") @gen_test def test_post_existing_kernel(self, _get_kernel_manager): kernel_manager_future = Future() @@ -159,13 +227,25 @@ def test_post_existing_kernel(self, _get_kernel_manager): res = yield self.reconnect_handler.post() assert_equals(res, None) - code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, self.auth) - self.client.execute.assert_called_once_with(code, silent=False, store_history=False) + code = "%{} -s {} -u {} -p {} -t {}".format( + KernelMagics._do_not_call_change_endpoint.__name__, + self.endpoint, + self.username, + self.password, + self.auth, + ) + self.client.execute.assert_called_once_with( + code, silent=False, store_history=False + ) self.reconnect_handler.set_status.assert_called_once_with(200) - self.reconnect_handler.finish.assert_called_once_with('{"error": null, "success": true}') - self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 200, True, None) - - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager') + self.reconnect_handler.finish.assert_called_once_with( + '{"error": null, "success": true}' + ) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + self.endpoint, 200, True, None + ) + + 
@patch("sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager") @gen_test def test_post_existing_kernel_failed(self, _get_kernel_manager): kernel_manager_future = Future() @@ -176,48 +256,80 @@ def test_post_existing_kernel_failed(self, _get_kernel_manager): res = yield self.reconnect_handler.post() assert_equals(res, None) - code = '%{} -s {} -u {} -p {} -t {}'.format(KernelMagics._do_not_call_change_endpoint.__name__, self.endpoint, self.username, self.password, self.auth) - self.client.execute.assert_called_once_with(code, silent=False, store_history=False) + code = "%{} -s {} -u {} -p {} -t {}".format( + KernelMagics._do_not_call_change_endpoint.__name__, + self.endpoint, + self.username, + self.password, + self.auth, + ) + self.client.execute.assert_called_once_with( + code, silent=False, store_history=False + ) self.reconnect_handler.set_status.assert_called_once_with(500) - self.reconnect_handler.finish.assert_called_once_with('{"error": "SyntaxError:\\noh no!", "success": false}') - self.spark_events.emit_cluster_change_event.assert_called_once_with(self.endpoint, 500, False, "SyntaxError:\noh no!") - - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session') + self.reconnect_handler.finish.assert_called_once_with( + '{"error": "SyntaxError:\\noh no!", "success": false}' + ) + self.spark_events.emit_cluster_change_event.assert_called_once_with( + self.endpoint, 500, False, "SyntaxError:\noh no!" + ) + + @patch( + "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session" + ) @gen_test - def test_get_kernel_manager_no_existing_kernel(self, _get_kernel_manager_new_session): + def test_get_kernel_manager_no_existing_kernel( + self, _get_kernel_manager_new_session + ): different_path = "different_path.ipynb" km_future = Future() km_future.set_result(self.individual_kernel_manager) _get_kernel_manager_new_session.return_value = km_future - - km = yield self.reconnect_handler._get_kernel_manager(different_path, self.kernel_name) + + km = yield self.reconnect_handler._get_kernel_manager( + different_path, self.kernel_name + ) assert_equals(self.individual_kernel_manager, km) self.individual_kernel_manager.restart_kernel.assert_not_called() self.kernel_manager.get_kernel.assert_not_called() - _get_kernel_manager_new_session.assert_called_once_with(different_path, self.kernel_name) + _get_kernel_manager_new_session.assert_called_once_with( + different_path, self.kernel_name + ) - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session') + @patch( + "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session" + ) @gen_test def test_get_kernel_manager_existing_kernel(self, _get_kernel_manager_new_session): - km = yield self.reconnect_handler._get_kernel_manager(self.path, self.kernel_name) + km = yield self.reconnect_handler._get_kernel_manager( + self.path, self.kernel_name + ) assert_equals(self.individual_kernel_manager, km) self.individual_kernel_manager.restart_kernel.assert_called_once_with() _get_kernel_manager_new_session.assert_not_called() - @patch('sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session') + @patch( + "sparkmagic.serverextension.handlers.ReconnectHandler._get_kernel_manager_new_session" + ) @gen_test - def test_get_kernel_manager_different_kernel_type(self, _get_kernel_manager_new_session): + def test_get_kernel_manager_different_kernel_type( + self, _get_kernel_manager_new_session + ): 
different_kernel = "sparkkernel" km_future = Future() km_future.set_result(self.individual_kernel_manager) _get_kernel_manager_new_session.return_value = km_future - km = yield self.reconnect_handler._get_kernel_manager(self.path, different_kernel) + km = yield self.reconnect_handler._get_kernel_manager( + self.path, different_kernel + ) assert_equals(self.individual_kernel_manager, km) self.individual_kernel_manager.restart_kernel.assert_not_called() self.kernel_manager.get_kernel.assert_not_called() - _get_kernel_manager_new_session.assert_called_once_with(self.path, different_kernel) + _get_kernel_manager_new_session.assert_called_once_with( + self.path, different_kernel + ) self.session_manager.delete_session.assert_called_once_with(self.session_id) diff --git a/sparkmagic/sparkmagic/tests/test_heartbeatthread.py b/sparkmagic/sparkmagic/tests/test_heartbeatthread.py index 4bdfc6bc5..df970f080 100644 --- a/sparkmagic/sparkmagic/tests/test_heartbeatthread.py +++ b/sparkmagic/sparkmagic/tests/test_heartbeatthread.py @@ -10,40 +10,40 @@ def test_create_thread(): refresh_seconds = 1 retry_seconds = 2 heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds) - + assert_equals(heartbeat_thread.livy_session, session) assert_equals(heartbeat_thread.refresh_seconds, refresh_seconds) assert_equals(heartbeat_thread.retry_seconds, retry_seconds) - - + + def test_run_once(): session = MagicMock() refresh_seconds = 0.1 retry_seconds = 2 heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds, 1) - + heartbeat_thread.start() sleep(0.15) heartbeat_thread.stop() - + session.refresh_status_and_info.assert_called_once_with() assert heartbeat_thread.livy_session is None - - + + def test_run_stops(): session = MagicMock() refresh_seconds = 0.01 retry_seconds = 2 heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds) - + heartbeat_thread.start() sleep(0.1) heartbeat_thread.stop() - + assert session.refresh_status_and_info.called assert heartbeat_thread.livy_session is None - - + + def test_run_retries(): msg = "oh noes!" session = MagicMock() @@ -51,16 +51,16 @@ def test_run_retries(): refresh_seconds = 0.1 retry_seconds = 0.1 heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds, 1) - + heartbeat_thread.start() sleep(0.15) heartbeat_thread.stop() - + session.refresh_status_and_info.assert_called_once_with() session.logger.error.assert_called_once_with(msg) assert heartbeat_thread.livy_session is None - - + + def test_run_retries_stops(): msg = "oh noes!" 
session = MagicMock() @@ -68,12 +68,11 @@ def test_run_retries_stops(): refresh_seconds = 0.01 retry_seconds = 0.01 heartbeat_thread = _HeartbeatThread(session, refresh_seconds, retry_seconds) - + heartbeat_thread.start() sleep(0.1) heartbeat_thread.stop() - + assert session.refresh_status_and_info.called assert session.logger.error.called assert heartbeat_thread.livy_session is None - \ No newline at end of file diff --git a/sparkmagic/sparkmagic/tests/test_kernel_magics.py b/sparkmagic/sparkmagic/tests/test_kernel_magics.py index 98cd208e7..5a15bc910 100644 --- a/sparkmagic/sparkmagic/tests/test_kernel_magics.py +++ b/sparkmagic/sparkmagic/tests/test_kernel_magics.py @@ -7,9 +7,16 @@ import sparkmagic.utils.constants as constants from sparkmagic.kernels.kernelmagics import KernelMagics, Namespace from sparkmagic.magics.remotesparkmagics import RemoteSparkMagics -from sparkmagic.livyclientlib.exceptions import LivyClientTimeoutException, BadUserDataException,\ - LivyUnexpectedStatusException, SessionManagementException,\ - HttpClientException, DataFrameParseException, SqlContextNotFoundException, SparkStatementException +from sparkmagic.livyclientlib.exceptions import ( + LivyClientTimeoutException, + BadUserDataException, + LivyUnexpectedStatusException, + SessionManagementException, + HttpClientException, + DataFrameParseException, + SqlContextNotFoundException, + SparkStatementException, +) from sparkmagic.livyclientlib.endpoint import Endpoint from sparkmagic.livyclientlib.command import Command from sparkmagic.auth.basic import Basic @@ -21,6 +28,7 @@ ipython_display = MagicMock() spark_events = None + @magics_class class TestKernelMagics(KernelMagics): def __init__(self, shell, data=None, spark_events=None): @@ -43,7 +51,7 @@ def _setup(): magic.shell = shell = MagicMock() magic.ipython_display = ipython_display = MagicMock() magic.spark_controller = spark_controller = MagicMock() - magic._generate_uuid = MagicMock(return_value='0000') + magic._generate_uuid = MagicMock(return_value="0000") def _teardown(): @@ -65,8 +73,12 @@ def test_start_session(): assert ret assert magic.session_started - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) # Call a second time ret = magic._do_not_call_start_session(line) @@ -95,12 +107,11 @@ def test_start_session_times_out(): assert not magic.session_started assert_equals(ipython_display.send_error.call_count, 1) + @with_setup(_setup, _teardown) @raises(LivyClientTimeoutException) def test_start_session_times_out_all_errors_are_fatal(): - conf.override_all({ - "all_errors_are_fatal": True - }) + conf.override_all({"all_errors_are_fatal": True}) line = "" spark_controller.add_session = MagicMock(side_effect=LivyClientTimeoutException) @@ -129,14 +140,17 @@ def test_delete_session(): def test_delete_session_expected_exception(): line = "" magic.session_started = True - spark_controller.delete_session_by_name.side_effect = BadUserDataException('hey') + spark_controller.delete_session_by_name.side_effect = BadUserDataException("hey") magic._do_not_call_delete_session(line) assert not magic.session_started spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - 
.format(spark_controller.delete_session_by_name.side_effect)) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format( + spark_controller.delete_session_by_name.side_effect + ) + ) @with_setup(_setup, _teardown) @@ -174,30 +188,35 @@ def test_change_language_not_valid(): assert_equals(constants.LANG_PYTHON, magic.language) assert_equals(Endpoint("url", None), magic.endpoint) + @with_setup(_setup, _teardown) def test_change_endpoint(): - s = 'server' - u = 'user' - p = 'password' + s = "server" + u = "user" + p = "password" t = constants.AUTH_BASIC line = "-s {} -u {} -p {} -t {}".format(s, u, p, t) magic._do_not_call_change_endpoint(line) - args = Namespace(auth='Basic_Access', password='password', url='server', user='user') + args = Namespace( + auth="Basic_Access", password="password", url="server", user="user" + ) auth_instance = initialize_auth(args) endpoint = Endpoint(s, auth_instance) assert_equals(endpoint.url, magic.endpoint.url) assert_equals(Endpoint(s, auth_instance), magic.endpoint) + @with_setup(_setup, _teardown) @raises(BadUserDataException) def test_change_endpoint_session_started(): - u = 'user' - p = 'password' - s = 'server' + u = "user" + p = "password" + s = "server" line = "-s {} -u {} -p {}".format(s, u, p) magic.session_started = True magic._do_not_call_change_endpoint(line) + @with_setup(_setup, _teardown) def test_info(): magic._print_endpoint_info = print_info_mock = MagicMock() @@ -208,10 +227,15 @@ def test_info(): magic.info(line) - print_info_mock.assert_called_once_with(session_info, spark_controller.get_session_id_for_client.return_value) - spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name) + print_info_mock.assert_called_once_with( + session_info, spark_controller.get_session_id_for_client.return_value + ) + spark_controller.get_session_id_for_client.assert_called_once_with( + magic.session_name + ) + + _assert_magic_successful_event_emitted_once("info") - _assert_magic_successful_event_emitted_once('info') @with_setup(_setup, _teardown) def test_info_without_active_session(): @@ -224,22 +248,25 @@ def test_info_without_active_session(): print_info_mock.assert_called_once_with(session_info, None) - _assert_magic_successful_event_emitted_once('info') + _assert_magic_successful_event_emitted_once("info") + @with_setup(_setup, _teardown) def test_info_with_cell_content(): magic._print_endpoint_info = print_info_mock = MagicMock() line = "" session_info = ["1", "2"] - spark_controller.get_all_sessions_endpoint_info = MagicMock(return_value=session_info) + spark_controller.get_all_sessions_endpoint_info = MagicMock( + return_value=session_info + ) error_msg = "Cell body for %%info magic must be empty; got 'howdy' instead" - magic.info(line, cell='howdy') + magic.info(line, cell="howdy") print_info_mock.assert_not_called() assert_equals(ipython_display.send_error.call_count, 1) spark_controller.get_session_id_for_client.assert_not_called() - _assert_magic_failure_event_emitted_once('info', BadUserDataException(error_msg)) + _assert_magic_failure_event_emitted_once("info", BadUserDataException(error_msg)) @with_setup(_setup, _teardown) @@ -247,7 +274,9 @@ def test_info_with_argument(): magic._print_endpoint_info = print_info_mock = MagicMock() line = "hey" session_info = ["1", "2"] - spark_controller.get_all_sessions_endpoint_info = MagicMock(return_value=session_info) + spark_controller.get_all_sessions_endpoint_info = MagicMock( + return_value=session_info + ) magic.info(line) @@ -260,24 
+289,38 @@ def test_info_with_argument(): def test_info_unexpected_exception(): magic._print_endpoint_info = MagicMock() line = "" - spark_controller.get_all_sessions_endpoint = MagicMock(side_effect=ValueError('utter failure')) + spark_controller.get_all_sessions_endpoint = MagicMock( + side_effect=ValueError("utter failure") + ) magic.info(line) - _assert_magic_failure_event_emitted_once('info', spark_controller.get_all_sessions_endpoint.side_effect) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - .format(spark_controller.get_all_sessions_endpoint.side_effect)) + _assert_magic_failure_event_emitted_once( + "info", spark_controller.get_all_sessions_endpoint.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format( + spark_controller.get_all_sessions_endpoint.side_effect + ) + ) @with_setup(_setup, _teardown) def test_info_expected_exception(): magic._print_endpoint_info = MagicMock() line = "" - spark_controller.get_all_sessions_endpoint = MagicMock(side_effect=SqlContextNotFoundException('utter failure')) + spark_controller.get_all_sessions_endpoint = MagicMock( + side_effect=SqlContextNotFoundException("utter failure") + ) magic.info(line) - _assert_magic_failure_event_emitted_once('info', spark_controller.get_all_sessions_endpoint.side_effect) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - .format(spark_controller.get_all_sessions_endpoint.side_effect)) + _assert_magic_failure_event_emitted_once( + "info", spark_controller.get_all_sessions_endpoint.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format( + spark_controller.get_all_sessions_endpoint.side_effect + ) + ) @with_setup(_setup, _teardown) @@ -285,7 +328,7 @@ def test_help(): magic.help("") assert_equals(ipython_display.html.call_count, 1) - _assert_magic_successful_event_emitted_once('help') + _assert_magic_successful_event_emitted_once("help") @with_setup(_setup, _teardown) @@ -295,7 +338,7 @@ def test_help_with_cell_content(): assert_equals(ipython_display.send_error.call_count, 1) assert_equals(ipython_display.html.call_count, 0) - _assert_magic_failure_event_emitted_once('help', BadUserDataException(msg)) + _assert_magic_failure_event_emitted_once("help", BadUserDataException(msg)) @with_setup(_setup, _teardown) @@ -313,7 +356,7 @@ def test_logs(): magic.logs(line) ipython_display.write.assert_called_once_with("No logs yet.") - _assert_magic_successful_event_emitted_once('logs') + _assert_magic_successful_event_emitted_once("logs") ipython_display.write.reset_mock() @@ -333,7 +376,7 @@ def test_logs_with_cell_content(): magic.logs(line, cell="BOOP") assert_equals(ipython_display.send_error.call_count, 1) - _assert_magic_failure_event_emitted_once('logs', BadUserDataException(msg)) + _assert_magic_failure_event_emitted_once("logs", BadUserDataException(msg)) @with_setup(_setup, _teardown) @@ -350,12 +393,17 @@ def test_logs_unexpected_exception(): magic.session_started = True - spark_controller.get_logs = MagicMock(side_effect=SyntaxError('There was some sort of error')) + spark_controller.get_logs = MagicMock( + side_effect=SyntaxError("There was some sort of error") + ) magic.logs(line) spark_controller.get_logs.assert_called_once_with() - _assert_magic_failure_event_emitted_once('logs', spark_controller.get_logs.side_effect) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - 
.format(spark_controller.get_logs.side_effect)) + _assert_magic_failure_event_emitted_once( + "logs", spark_controller.get_logs.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format(spark_controller.get_logs.side_effect) + ) @with_setup(_setup, _teardown) @@ -364,12 +412,17 @@ def test_logs_expected_exception(): magic.session_started = True - spark_controller.get_logs = MagicMock(side_effect=LivyUnexpectedStatusException('There was some sort of error')) + spark_controller.get_logs = MagicMock( + side_effect=LivyUnexpectedStatusException("There was some sort of error") + ) magic.logs(line) spark_controller.get_logs.assert_called_once_with() - _assert_magic_failure_event_emitted_once('logs', spark_controller.get_logs.side_effect) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - .format(spark_controller.get_logs.side_effect)) + _assert_magic_failure_event_emitted_once( + "logs", spark_controller.get_logs.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format(spark_controller.get_logs.side_effect) + ) @with_setup(_setup, _teardown) @@ -379,26 +432,30 @@ def test_configure(): # Session not started conf.override_all({}) - magic.configure('', '{"extra": "yes"}') + magic.configure("", '{"extra": "yes"}') assert_equals(conf.session_configs(), {"extra": "yes"}) - _assert_magic_successful_event_emitted_once('configure') + _assert_magic_successful_event_emitted_once("configure") magic.info.assert_called_once_with("") # Session started - no -f magic.session_started = True conf.override_all({}) - magic.configure('', "{\"extra\": \"yes\"}") + magic.configure("", '{"extra": "yes"}') assert_equals(conf.session_configs(), {}) assert_equals(ipython_display.send_error.call_count, 1) # Session started - with -f magic.info.reset_mock() conf.override_all({}) - magic.configure("-f", "{\"extra\": \"yes\"}") + magic.configure("-f", '{"extra": "yes"}') assert_equals(conf.session_configs(), {"extra": "yes"}) spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK, "extra": "yes"}) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK, "extra": "yes"}, + ) magic.info.assert_called_once_with("") @@ -406,31 +463,45 @@ def test_configure(): def test_configure_unexpected_exception(): magic.info = MagicMock() - magic._override_session_settings = MagicMock(side_effect=ValueError('help')) - magic.configure('', '{"extra": "yes"}') - _assert_magic_failure_event_emitted_once('configure', magic._override_session_settings.side_effect) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG\ - .format(magic._override_session_settings.side_effect)) + magic._override_session_settings = MagicMock(side_effect=ValueError("help")) + magic.configure("", '{"extra": "yes"}') + _assert_magic_failure_event_emitted_once( + "configure", magic._override_session_settings.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format( + magic._override_session_settings.side_effect + ) + ) @with_setup(_setup, _teardown) def test_configure_expected_exception(): magic.info = MagicMock() - magic._override_session_settings = MagicMock(side_effect=BadUserDataException('help')) - 
magic.configure('', '{"extra": "yes"}') - _assert_magic_failure_event_emitted_once('configure', magic._override_session_settings.side_effect) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG\ - .format(magic._override_session_settings.side_effect)) + magic._override_session_settings = MagicMock( + side_effect=BadUserDataException("help") + ) + magic.configure("", '{"extra": "yes"}') + _assert_magic_failure_event_emitted_once( + "configure", magic._override_session_settings.side_effect + ) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format( + magic._override_session_settings.side_effect + ) + ) @with_setup(_setup, _teardown) def test_configure_cant_parse_object_as_json(): magic.info = MagicMock() - magic._override_session_settings = MagicMock(side_effect=BadUserDataException('help')) - magic.configure('', "I CAN'T PARSE THIS AS JSON") - _assert_magic_successful_event_emitted_once('configure') + magic._override_session_settings = MagicMock( + side_effect=BadUserDataException("help") + ) + magic.configure("", "I CAN'T PARSE THIS AS JSON") + _assert_magic_successful_event_emitted_once("configure") assert_equals(ipython_display.send_error.call_count, 1) @@ -443,6 +514,7 @@ def test_get_session_settings(): assert magic.get_session_settings("something -f", True) == "something" assert magic.get_session_settings("something", True) is None + @with_setup(_setup, _teardown) def test_send_to_spark_with_non_empty_cell_error(): line = "-i input -n name -t str" @@ -455,6 +527,7 @@ def test_send_to_spark_with_non_empty_cell_error(): assert_equals(ipython_display.send_error.call_count, 1) + @with_setup(_setup, _teardown) def test_send_to_spark_with_no_i_param_error(): line = "-n name -t str" @@ -467,6 +540,7 @@ def test_send_to_spark_with_no_i_param_error(): assert_equals(ipython_display.send_error.call_count, 1) + @with_setup(_setup, _teardown) def test_send_to_spark_ok(): line = "-i input -n name -t str" @@ -477,21 +551,32 @@ def test_send_to_spark_ok(): magic.send_to_spark(line, cell) assert ipython_display.write.called - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) spark_controller.run_command.assert_called_once_with(Command(cell), None) + @with_setup(_setup, _teardown) def test_spark(): line = "" cell = "some spark code" - spark_controller.run_command = MagicMock(return_value=(True, line, constants.MIMETYPE_TEXT_PLAIN)) + spark_controller.run_command = MagicMock( + return_value=(True, line, constants.MIMETYPE_TEXT_PLAIN) + ) magic.spark(line, cell) ipython_display.write.assert_called_once_with(line) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) spark_controller.run_command.assert_called_once_with(Command(cell), None) @@ -510,13 +595,21 @@ def test_spark_with_argument(): def test_spark_error(): line = "" cell = "some spark code" - spark_controller.run_command = MagicMock(return_value=(False, line, constants.MIMETYPE_TEXT_PLAIN)) + spark_controller.run_command = MagicMock( + return_value=(False, line, constants.MIMETYPE_TEXT_PLAIN) + ) 
magic.spark(line, cell) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG.format(line)) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format(line) + ) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) spark_controller.run_command.assert_called_once_with(Command(cell), None) @@ -538,37 +631,45 @@ def test_spark_failed_session_start(): def test_spark_unexpected_exception(): line = "" cell = "some spark code" - spark_controller.run_command = MagicMock(side_effect=Exception('oups')) + spark_controller.run_command = MagicMock(side_effect=Exception("oups")) magic.spark(line, cell) spark_controller.run_command.assert_called_once_with(Command(cell), None) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - .format(spark_controller.run_command.side_effect)) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format(spark_controller.run_command.side_effect) + ) @with_setup(_setup, _teardown) def test_spark_expected_exception(): line = "" cell = "some spark code" - spark_controller.run_command = MagicMock(side_effect=SessionManagementException('oups')) + spark_controller.run_command = MagicMock( + side_effect=SessionManagementException("oups") + ) magic.spark(line, cell) spark_controller.run_command.assert_called_once_with(Command(cell), None) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - .format(spark_controller.run_command.side_effect)) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format(spark_controller.run_command.side_effect) + ) @with_setup(_setup, _teardown) @raises(SparkStatementException) def test_spark_fatal_spark_statement_exception(): - conf.override_all({ - "all_errors_are_fatal": True, - }) + conf.override_all( + { + "all_errors_are_fatal": True, + } + ) line = "" cell = "some spark code" - spark_controller.run_command = MagicMock(side_effect=SparkStatementException('Oh no!')) + spark_controller.run_command = MagicMock( + side_effect=SparkStatementException("Oh no!") + ) magic.spark(line, cell) @@ -577,28 +678,33 @@ def test_spark_fatal_spark_statement_exception(): def test_spark_unexpected_exception_in_storing(): line = "-o var_name" cell = "some spark code" - side_effect = [(True, 'ok', constants.MIMETYPE_TEXT_PLAIN), Exception('oups')] + side_effect = [(True, "ok", constants.MIMETYPE_TEXT_PLAIN), Exception("oups")] spark_controller.run_command = MagicMock(side_effect=side_effect) magic.spark(line, cell) assert_equals(spark_controller.run_command.call_count, 2) spark_controller.run_command.assert_any_call(Command(cell), None) - ipython_display.send_error.assert_called_with(constants.INTERNAL_ERROR_MSG - .format(side_effect[1])) + ipython_display.send_error.assert_called_with( + constants.INTERNAL_ERROR_MSG.format(side_effect[1]) + ) @with_setup(_setup, _teardown) def test_spark_expected_exception_in_storing(): line = "-o var_name" cell = "some spark code" - side_effect = [(True, 'ok', constants.MIMETYPE_TEXT_PLAIN), SessionManagementException('oups')] + side_effect = [ + (True, "ok", constants.MIMETYPE_TEXT_PLAIN), + SessionManagementException("oups"), + ] spark_controller.run_command = MagicMock(side_effect=side_effect) 
magic.spark(line, cell) assert spark_controller.run_command.call_count == 2 spark_controller.run_command.assert_any_call(Command(cell), None) - ipython_display.send_error.assert_called_with(constants.EXPECTED_ERROR_MSG - .format(side_effect[1])) + ipython_display.send_error.assert_called_with( + constants.EXPECTED_ERROR_MSG.format(side_effect[1]) + ) @with_setup(_setup, _teardown) @@ -608,7 +714,9 @@ def test_spark_sample_options(): magic.execute_spark = MagicMock() ret = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, True) + magic.execute_spark.assert_called_once_with( + cell, "var_name", "sample", 142, 0.3, None, True + ) @with_setup(_setup, _teardown) @@ -618,7 +726,9 @@ def test_spark_false_coerce(): magic.execute_spark = MagicMock() ret = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, False) + magic.execute_spark.assert_called_once_with( + cell, "var_name", "sample", 142, 0.3, None, False + ) @with_setup(_setup, _teardown) @@ -629,9 +739,15 @@ def test_sql_without_output(): magic.sql(line, cell) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) - magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, None, False, None) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) + magic.execute_sqlquery.assert_called_once_with( + cell, None, None, None, None, None, False, None + ) @with_setup(_setup, _teardown) @@ -642,33 +758,45 @@ def test_sql_with_output(): magic.sql(line, cell) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) - magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) + magic.execute_sqlquery.assert_called_once_with( + cell, None, None, None, None, "my_var", False, None + ) @with_setup(_setup, _teardown) def test_sql_exception(): line = "-o my_var" cell = "some spark code" - magic.execute_sqlquery = MagicMock(side_effect=ValueError('HAHAHAHAH')) + magic.execute_sqlquery = MagicMock(side_effect=ValueError("HAHAHAHAH")) magic.sql(line, cell) - magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - .format(magic.execute_sqlquery.side_effect)) + magic.execute_sqlquery.assert_called_once_with( + cell, None, None, None, None, "my_var", False, None + ) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format(magic.execute_sqlquery.side_effect) + ) @with_setup(_setup, _teardown) def test_sql_expected_exception(): line = "-o my_var" cell = "some spark code" - magic.execute_sqlquery = MagicMock(side_effect=HttpClientException('HAHAHAHAH')) + magic.execute_sqlquery = MagicMock(side_effect=HttpClientException("HAHAHAHAH")) magic.sql(line, cell) - magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - .format(magic.execute_sqlquery.side_effect)) + 
magic.execute_sqlquery.assert_called_once_with( + cell, None, None, None, None, "my_var", False, None + ) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format(magic.execute_sqlquery.side_effect) + ) @with_setup(_setup, _teardown) @@ -692,9 +820,15 @@ def test_sql_quiet(): ret = magic.sql(line, cell) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) - magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "Output", True, None) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) + magic.execute_sqlquery.assert_called_once_with( + cell, None, None, None, None, "Output", True, None + ) @with_setup(_setup, _teardown) @@ -705,9 +839,15 @@ def test_sql_sample_options(): ret = magic.sql(line, cell) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) - magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, True) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) + magic.execute_sqlquery.assert_called_once_with( + cell, "sample", 142, 0.3, None, None, True, True + ) @with_setup(_setup, _teardown) @@ -718,9 +858,15 @@ def test_sql_false_coerce(): ret = magic.sql(line, cell) - spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_PYSPARK}) - magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, False) + spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_PYSPARK}, + ) + magic.execute_sqlquery.assert_called_once_with( + cell, "sample", 142, 0.3, None, None, True, False + ) @with_setup(_setup, _teardown) @@ -733,7 +879,7 @@ def test_cleanup_without_force(): magic.cleanup(line, cell) - _assert_magic_successful_event_emitted_once('cleanup') + _assert_magic_successful_event_emitted_once("cleanup") assert_equals(ipython_display.send_error.call_count, 1) assert_equals(spark_controller.cleanup_endpoint.call_count, 0) @@ -749,7 +895,7 @@ def test_cleanup_with_force(): magic.cleanup(line, cell) - _assert_magic_successful_event_emitted_once('cleanup') + _assert_magic_successful_event_emitted_once("cleanup") spark_controller.cleanup_endpoint.assert_called_once_with(magic.endpoint) spark_controller.delete_session_by_name.assert_called_once_with(magic.session_name) @@ -766,7 +912,7 @@ def test_cleanup_with_cell_content(): magic.cleanup(line, cell) assert_equals(ipython_display.send_error.call_count, 1) - _assert_magic_failure_event_emitted_once('cleanup', BadUserDataException(msg)) + _assert_magic_failure_event_emitted_once("cleanup", BadUserDataException(msg)) @with_setup(_setup, _teardown) @@ -774,13 +920,20 @@ def test_cleanup_exception(): line = "-f" cell = "" magic.session_started = True - spark_controller.cleanup_endpoint = MagicMock(side_effect=ArithmeticError('DIVISION BY ZERO OH NO')) + spark_controller.cleanup_endpoint = MagicMock( + side_effect=ArithmeticError("DIVISION BY ZERO OH NO") + ) magic.cleanup(line, cell) - _assert_magic_failure_event_emitted_once('cleanup', spark_controller.cleanup_endpoint.side_effect) + 
_assert_magic_failure_event_emitted_once( + "cleanup", spark_controller.cleanup_endpoint.side_effect + ) spark_controller.cleanup_endpoint.assert_called_once_with(magic.endpoint) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - .format(spark_controller.cleanup_endpoint.side_effect)) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format( + spark_controller.cleanup_endpoint.side_effect + ) + ) @with_setup(_setup, _teardown) @@ -793,7 +946,7 @@ def test_delete_without_force(): magic.delete(line, cell) - _assert_magic_successful_event_emitted_once('delete') + _assert_magic_successful_event_emitted_once("delete") assert_equals(ipython_display.send_error.call_count, 1) assert_equals(spark_controller.delete_session_by_id.call_count, 0) @@ -809,7 +962,7 @@ def test_delete_without_session_id(): magic.delete(line, cell) - _assert_magic_successful_event_emitted_once('delete') + _assert_magic_successful_event_emitted_once("delete") assert_equals(ipython_display.send_error.call_count, 1) assert_equals(spark_controller.delete_session_by_id.call_count, 0) @@ -825,7 +978,7 @@ def test_delete_with_force_same_session(): magic.delete(line, cell) - _assert_magic_successful_event_emitted_once('delete') + _assert_magic_successful_event_emitted_once("delete") assert_equals(ipython_display.send_error.call_count, 1) assert_equals(spark_controller.delete_session_by_id.call_count, 0) @@ -842,10 +995,14 @@ def test_delete_with_force_none_session(): magic.delete(line, cell) - _assert_magic_successful_event_emitted_once('delete') + _assert_magic_successful_event_emitted_once("delete") - spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name) - spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id) + spark_controller.get_session_id_for_client.assert_called_once_with( + magic.session_name + ) + spark_controller.delete_session_by_id.assert_called_once_with( + magic.endpoint, session_id + ) @with_setup(_setup, _teardown) @@ -860,7 +1017,7 @@ def test_delete_with_cell_content(): magic.delete(line, cell) - _assert_magic_failure_event_emitted_once('delete', BadUserDataException(msg)) + _assert_magic_failure_event_emitted_once("delete", BadUserDataException(msg)) assert_equals(ipython_display.send_error.call_count, 1) @@ -875,10 +1032,14 @@ def test_delete_with_force_different_session(): magic.delete(line, cell) - _assert_magic_successful_event_emitted_once('delete') + _assert_magic_successful_event_emitted_once("delete") - spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name) - spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id) + spark_controller.get_session_id_for_client.assert_called_once_with( + magic.session_name + ) + spark_controller.delete_session_by_id.assert_called_once_with( + magic.endpoint, session_id + ) @with_setup(_setup, _teardown) @@ -887,15 +1048,26 @@ def test_delete_exception(): session_id = 0 line = "-f -s {}".format(session_id) cell = "" - spark_controller.delete_session_by_id = MagicMock(side_effect=DataFrameParseException('wow')) + spark_controller.delete_session_by_id = MagicMock( + side_effect=DataFrameParseException("wow") + ) spark_controller.get_session_id_for_client = MagicMock() magic.delete(line, cell) - _assert_magic_failure_event_emitted_once('delete', spark_controller.delete_session_by_id.side_effect) - 
spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name) - spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id) - ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG - .format(spark_controller.delete_session_by_id.side_effect)) + _assert_magic_failure_event_emitted_once( + "delete", spark_controller.delete_session_by_id.side_effect + ) + spark_controller.get_session_id_for_client.assert_called_once_with( + magic.session_name + ) + spark_controller.delete_session_by_id.assert_called_once_with( + magic.endpoint, session_id + ) + ipython_display.send_error.assert_called_once_with( + constants.INTERNAL_ERROR_MSG.format( + spark_controller.delete_session_by_id.side_effect + ) + ) @with_setup(_setup, _teardown) @@ -904,15 +1076,26 @@ def test_delete_expected_exception(): session_id = 0 line = "-f -s {}".format(session_id) cell = "" - spark_controller.delete_session_by_id = MagicMock(side_effect=LivyClientTimeoutException('wow')) + spark_controller.delete_session_by_id = MagicMock( + side_effect=LivyClientTimeoutException("wow") + ) spark_controller.get_session_id_for_client = MagicMock() magic.delete(line, cell) - _assert_magic_failure_event_emitted_once('delete', spark_controller.delete_session_by_id.side_effect) - spark_controller.get_session_id_for_client.assert_called_once_with(magic.session_name) - spark_controller.delete_session_by_id.assert_called_once_with(magic.endpoint, session_id) - ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG - .format(spark_controller.delete_session_by_id.side_effect)) + _assert_magic_failure_event_emitted_once( + "delete", spark_controller.delete_session_by_id.side_effect + ) + spark_controller.get_session_id_for_client.assert_called_once_with( + magic.session_name + ) + spark_controller.delete_session_by_id.assert_called_once_with( + magic.endpoint, session_id + ) + ipython_display.send_error.assert_called_once_with( + constants.EXPECTED_ERROR_MSG.format( + spark_controller.delete_session_by_id.side_effect + ) + ) @with_setup(_setup, _teardown) @@ -925,8 +1108,12 @@ def test_start_session_displays_fatal_error_when_session_throws(): magic._do_not_call_start_session("") - magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_SPARK}) + magic.spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_SPARK}, + ) assert magic.fatal_error assert magic.fatal_error_message == conf.fatal_error_suggestion().format(str(e)) @@ -945,8 +1132,12 @@ def test_start_session_when_retry_fatal_error_is_not_allowed_by_default(): magic._do_not_call_start_session("") # call add_session once and call send_error twice to show the error message - magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_SPARK}) + magic.spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_SPARK}, + ) assert_equals(magic.ipython_display.send_error.call_count, 2) @@ -961,8 +1152,12 @@ def test_retry_start_session_when_retry_fatal_error_is_allowed(): # first session creation - failed session_created = magic._do_not_call_start_session("") - magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": 
constants.SESSION_KIND_SPARK}) + magic.spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_SPARK}, + ) assert not session_created assert magic.fatal_error assert magic.fatal_error_message == conf.fatal_error_suggestion().format(str(e)) @@ -970,13 +1165,17 @@ def test_retry_start_session_when_retry_fatal_error_is_allowed(): # retry session creation - successful magic.spark_controller.add_session = MagicMock() session_created = magic._do_not_call_start_session("") - magic.spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False, - {"kind": constants.SESSION_KIND_SPARK}) + magic.spark_controller.add_session.assert_called_once_with( + magic.session_name, + magic.endpoint, + False, + {"kind": constants.SESSION_KIND_SPARK}, + ) print(session_created) assert session_created assert magic.session_started assert not magic.fatal_error - assert magic.fatal_error_message == u'' + assert magic.fatal_error_message == "" # show error message only once assert magic.ipython_display.send_error.call_count == 1 @@ -993,30 +1192,42 @@ def test_allow_retry_fatal(): def test_kernel_magics_names(): """The magics machinery in IPython depends on the docstrings and method names matching up correctly""" - assert_equals(magic.help.__name__, 'help') - assert_equals(magic.local.__name__, 'local') - assert_equals(magic.info.__name__, 'info') - assert_equals(magic.logs.__name__, 'logs') - assert_equals(magic.configure.__name__, 'configure') - assert_equals(magic.spark.__name__, 'spark') - assert_equals(magic.sql.__name__, 'sql') - assert_equals(magic.cleanup.__name__, 'cleanup') - assert_equals(magic.delete.__name__, 'delete') + assert_equals(magic.help.__name__, "help") + assert_equals(magic.local.__name__, "local") + assert_equals(magic.info.__name__, "info") + assert_equals(magic.logs.__name__, "logs") + assert_equals(magic.configure.__name__, "configure") + assert_equals(magic.spark.__name__, "spark") + assert_equals(magic.sql.__name__, "sql") + assert_equals(magic.cleanup.__name__, "cleanup") + assert_equals(magic.delete.__name__, "delete") def _assert_magic_successful_event_emitted_once(name): magic._generate_uuid.assert_called_once_with() - spark_events.emit_magic_execution_start_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK, - magic._generate_uuid.return_value) - spark_events.emit_magic_execution_end_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK, - magic._generate_uuid.return_value, True, - '', '') + spark_events.emit_magic_execution_start_event.assert_called_once_with( + name, constants.SESSION_KIND_PYSPARK, magic._generate_uuid.return_value + ) + spark_events.emit_magic_execution_end_event.assert_called_once_with( + name, + constants.SESSION_KIND_PYSPARK, + magic._generate_uuid.return_value, + True, + "", + "", + ) def _assert_magic_failure_event_emitted_once(name, error): magic._generate_uuid.assert_called_once_with() - spark_events.emit_magic_execution_start_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK, - magic._generate_uuid.return_value) - spark_events.emit_magic_execution_end_event.assert_called_once_with(name, constants.SESSION_KIND_PYSPARK, - magic._generate_uuid.return_value, False, - error.__class__.__name__, str(error)) + spark_events.emit_magic_execution_start_event.assert_called_once_with( + name, constants.SESSION_KIND_PYSPARK, magic._generate_uuid.return_value + ) + 
spark_events.emit_magic_execution_end_event.assert_called_once_with( + name, + constants.SESSION_KIND_PYSPARK, + magic._generate_uuid.return_value, + False, + error.__class__.__name__, + str(error), + ) diff --git a/sparkmagic/sparkmagic/tests/test_kernels.py b/sparkmagic/sparkmagic/tests/test_kernels.py index 0a2c7a0a2..6867659ab 100644 --- a/sparkmagic/sparkmagic/tests/test_kernels.py +++ b/sparkmagic/sparkmagic/tests/test_kernels.py @@ -26,18 +26,15 @@ def test_pyspark_kernel_configs(): kernel = TestPyparkKernel() assert kernel.session_language == LANG_PYTHON - assert kernel.implementation == 'PySpark' + assert kernel.implementation == "PySpark" assert kernel.language == LANG_PYTHON - assert kernel.language_version == '0.1' + assert kernel.language_version == "0.1" assert kernel.language_info == { - 'name': 'pyspark', - 'mimetype': 'text/x-python', - 'codemirror_mode': { - 'name': 'python', - 'version': 3 - }, - 'file_extension': '.py', - 'pygments_lexer': 'python3', + "name": "pyspark", + "mimetype": "text/x-python", + "codemirror_mode": {"name": "python", "version": 3}, + "file_extension": ".py", + "pygments_lexer": "python3", } @@ -46,15 +43,15 @@ def test_spark_kernel_configs(): assert kernel.session_language == LANG_SCALA - assert kernel.implementation == 'Spark' + assert kernel.implementation == "Spark" assert kernel.language == LANG_SCALA - assert kernel.language_version == '0.1' + assert kernel.language_version == "0.1" assert kernel.language_info == { - 'name': 'scala', - 'mimetype': 'text/x-scala', - 'pygments_lexer': 'scala', - 'file_extension': '.sc', - 'codemirror_mode': 'text/x-scala', + "name": "scala", + "mimetype": "text/x-scala", + "pygments_lexer": "scala", + "file_extension": ".sc", + "codemirror_mode": "text/x-scala", } @@ -63,13 +60,13 @@ def test_sparkr_kernel_configs(): assert kernel.session_language == LANG_R - assert kernel.implementation == 'SparkR' + assert kernel.implementation == "SparkR" assert kernel.language == LANG_R - assert kernel.language_version == '0.1' + assert kernel.language_version == "0.1" assert kernel.language_info == { - 'name': 'sparkR', - 'mimetype': 'text/x-rsrc', - 'pygments_lexer': 'r', - 'file_extension': '.r', - 'codemirror_mode': 'text/x-rsrc' + "name": "sparkR", + "mimetype": "text/x-rsrc", + "pygments_lexer": "r", + "file_extension": ".r", + "codemirror_mode": "text/x-rsrc", } diff --git a/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py b/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py index 9d4c77091..4f0beb87a 100644 --- a/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py +++ b/sparkmagic/sparkmagic/tests/test_livyreliablehttpclient.py @@ -13,7 +13,7 @@ def test_post_statement(): http_client = MagicMock() livy_client = LivyReliableHttpClient(http_client, None) - data = {"adlfj":"sadflkjsdf"} + data = {"adlfj": "sadflkjsdf"} out = livy_client.post_statement(100, data) assert_equals(out, http_client.post.return_value.json.return_value) http_client.post.assert_called_once_with("/sessions/100/statements", [201], data) @@ -32,7 +32,9 @@ def test_cancel_statement(): livy_client = LivyReliableHttpClient(http_client, None) out = livy_client.cancel_statement(100, 104) assert_equals(out, http_client.post.return_value.json.return_value) - http_client.post.assert_called_once_with("/sessions/100/statements/104/cancel", [200], {}) + http_client.post.assert_called_once_with( + "/sessions/100/statements/104/cancel", [200], {} + ) def test_get_sessions(): @@ -46,7 +48,7 @@ def test_get_sessions(): def 
test_post_session(): http_client = MagicMock() livy_client = LivyReliableHttpClient(http_client, None) - properties = {"adlfj":"sadflkjsdf", 1: [2,3,4,5]} + properties = {"adlfj": "sadflkjsdf", 1: [2, 3, 4, 5]} out = livy_client.post_session(properties) assert_equals(out, http_client.post.return_value.json.return_value) http_client.post.assert_called_once_with("/sessions", [201], properties) @@ -77,7 +79,7 @@ def test_get_all_session_logs(): def test_custom_headers(): custom_headers = {"header1": "value1"} - overrides = { conf.custom_headers.__name__: custom_headers } + overrides = {conf.custom_headers.__name__: custom_headers} conf.override_all(overrides) endpoint = Endpoint("http://url.com", None) client = LivyReliableHttpClient.from_endpoint(endpoint) @@ -113,5 +115,5 @@ def test_retry_policy(): def _override_policy(policy): - overrides = { conf.retry_policy.__name__: policy } + overrides = {conf.retry_policy.__name__: policy} conf.override_all(overrides) diff --git a/sparkmagic/sparkmagic/tests/test_livysession.py b/sparkmagic/sparkmagic/tests/test_livysession.py index 111362bc8..2407daf2c 100644 --- a/sparkmagic/sparkmagic/tests/test_livysession.py +++ b/sparkmagic/sparkmagic/tests/test_livysession.py @@ -4,8 +4,12 @@ import sparkmagic.utils.constants as constants import sparkmagic.utils.configuration as conf -from sparkmagic.livyclientlib.exceptions import LivyClientTimeoutException, LivyUnexpectedStatusException,\ - BadUserDataException, SqlContextNotFoundException +from sparkmagic.livyclientlib.exceptions import ( + LivyClientTimeoutException, + LivyUnexpectedStatusException, + BadUserDataException, + SqlContextNotFoundException, +) from sparkmagic.livyclientlib.livysession import LivySession @@ -27,27 +31,43 @@ class TestLivySession(object): pi_result = "Pi is roughly 3.14336" - session_create_json = json.loads('{"id":0,"state":"starting","kind":"spark","log":[]}') - resource_limit_json = json.loads('{"id":0,"state":"starting","kind":"spark","log":[' - '"Queue\'s AM resource limit exceeded."]}') - ready_sessions_json = json.loads('{"id":0,"state":"idle","kind":"spark","log":[""]}') - recovering_sessions_json = json.loads('{"id":0,"state":"recovering","kind":"spark","log":[""]}') - error_sessions_json = json.loads('{"id":0,"state":"error","kind":"spark","log":[""]}') + session_create_json = json.loads( + '{"id":0,"state":"starting","kind":"spark","log":[]}' + ) + resource_limit_json = json.loads( + '{"id":0,"state":"starting","kind":"spark","log":[' + '"Queue\'s AM resource limit exceeded."]}' + ) + ready_sessions_json = json.loads( + '{"id":0,"state":"idle","kind":"spark","log":[""]}' + ) + recovering_sessions_json = json.loads( + '{"id":0,"state":"recovering","kind":"spark","log":[""]}' + ) + error_sessions_json = json.loads( + '{"id":0,"state":"error","kind":"spark","log":[""]}' + ) busy_sessions_json = json.loads('{"id":0,"state":"busy","kind":"spark","log":[""]}') post_statement_json = json.loads('{"id":0,"state":"running","output":null}') waiting_statement_json = json.loads('{"id":0,"state":"waiting","output":null}') running_statement_json = json.loads('{"id":0,"state":"running","output":null}') - ready_statement_json = json.loads('{"id":0,"state":"available","output":{"status":"ok",' - '"execution_count":0,"data":{"text/plain":"Pi is roughly 3.14336"}}}') - ready_statement_null_output_json = json.loads('{"id":0,"state":"available","output":null}') - ready_statement_failed_json = json.loads('{"id":0,"state":"available","output":{"status":"error",' - 
'"evalue":"error","traceback":"error"}}') + ready_statement_json = json.loads( + '{"id":0,"state":"available","output":{"status":"ok",' + '"execution_count":0,"data":{"text/plain":"Pi is roughly 3.14336"}}}' + ) + ready_statement_null_output_json = json.loads( + '{"id":0,"state":"available","output":null}' + ) + ready_statement_failed_json = json.loads( + '{"id":0,"state":"available","output":{"status":"error",' + '"evalue":"error","traceback":"error"}}' + ) log_json = json.loads('{"id":6,"from":0,"total":212,"log":["hi","hi"]}') def __init__(self): self.http_client = None self.spark_events = None - + self.get_statement_responses = [] self.post_statement_responses = [] self.get_session_responses = [] @@ -78,16 +98,19 @@ def _next_session_response_post(self, *args): self.post_session_responses = self.post_session_responses[1:] return val - def _create_session(self, kind=constants.SESSION_KIND_SPARK, session_id=-1, - heartbeat_timeout=60): + def _create_session( + self, kind=constants.SESSION_KIND_SPARK, session_id=-1, heartbeat_timeout=60 + ): ipython_display = MagicMock() - session = LivySession(self.http_client, - {"kind": kind}, - ipython_display, - session_id, - self.spark_events, - heartbeat_timeout, - self.heartbeat_thread) + session = LivySession( + self.http_client, + {"kind": kind}, + ipython_display, + session_id, + self.spark_events, + heartbeat_timeout, + self.heartbeat_thread, + ) return session def _create_session_with_fixed_get_response(self, get_session_json): @@ -109,22 +132,22 @@ def test_constructor_starts_with_existing_session(self): assert session.id == session_id assert session._heartbeat_thread is None - assert constants.LIVY_HEARTBEAT_TIMEOUT_PARAM not in list(session.properties.keys()) - + assert constants.LIVY_HEARTBEAT_TIMEOUT_PARAM not in list( + session.properties.keys() + ) + def test_constructor_starts_heartbeat_with_existing_session(self): - conf.override_all({ - "heartbeat_refresh_seconds": 0.1 - }) + conf.override_all({"heartbeat_refresh_seconds": 0.1}) session_id = 1 session = self._create_session(session_id=session_id) conf.override_all({}) - + assert session.id == session_id assert self.heartbeat_thread.daemon self.heartbeat_thread.start.assert_called_once_with() assert not session._heartbeat_thread is None - assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM ] > 0 - + assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM] > 0 + def test_start_with_heartbeat(self): self.http_client.post_session.return_value = self.session_create_json self.http_client.get_session.return_value = self.ready_sessions_json @@ -132,12 +155,12 @@ def test_start_with_heartbeat(self): session = self._create_session() session.start() - + assert self.heartbeat_thread.daemon self.heartbeat_thread.start.assert_called_once_with() assert not session._heartbeat_thread is None - assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM ] > 0 - + assert session.properties[constants.LIVY_HEARTBEAT_TIMEOUT_PARAM] > 0 + def test_start_with_heartbeat_calls_only_once(self): self.http_client.post_session.return_value = self.session_create_json self.http_client.get_session.return_value = self.ready_sessions_json @@ -151,7 +174,7 @@ def test_start_with_heartbeat_calls_only_once(self): assert self.heartbeat_thread.daemon self.heartbeat_thread.start.assert_called_once_with() assert not session._heartbeat_thread is None - + def test_delete_with_heartbeat(self): self.http_client.post_session.return_value = self.session_create_json 
self.http_client.get_session.return_value = self.ready_sessions_json @@ -160,9 +183,9 @@ def test_delete_with_heartbeat(self): session = self._create_session() session.start() heartbeat_thread = session._heartbeat_thread - + session.delete() - + self.heartbeat_thread.stop.assert_called_once_with() assert session._heartbeat_thread is None @@ -195,7 +218,9 @@ def test_start_scala_starts_session(self): assert_equals(kind, session.kind) assert_equals("idle", session.status) assert_equals(0, session.id) - self.http_client.post_session.assert_called_with({"kind": "spark", "heartbeatTimeoutInSecond": 60}) + self.http_client.post_session.assert_called_with( + {"kind": "spark", "heartbeatTimeoutInSecond": 60} + ) def test_start_r_starts_session(self): self.http_client.post_session.return_value = self.session_create_json @@ -209,7 +234,9 @@ def test_start_r_starts_session(self): assert_equals(kind, session.kind) assert_equals("idle", session.status) assert_equals(0, session.id) - self.http_client.post_session.assert_called_with({"kind": "sparkr", "heartbeatTimeoutInSecond": 60}) + self.http_client.post_session.assert_called_with( + {"kind": "sparkr", "heartbeatTimeoutInSecond": 60} + ) def test_start_python_starts_session(self): self.http_client.post_session.return_value = self.session_create_json @@ -223,7 +250,9 @@ def test_start_python_starts_session(self): assert_equals(kind, session.kind) assert_equals("idle", session.status) assert_equals(0, session.id) - self.http_client.post_session.assert_called_with({"kind": "pyspark", "heartbeatTimeoutInSecond": 60}) + self.http_client.post_session.assert_called_with( + {"kind": "pyspark", "heartbeatTimeoutInSecond": 60} + ) def test_start_passes_in_all_properties(self): self.http_client.post_session.return_value = self.session_create_json @@ -257,12 +286,14 @@ def test_status_recovering(self): Ensure 'recovering' state is supported: we go from recovering to idle. 
""" self.http_client.post_session.return_value = self.session_create_json + def get_session(i, calls=[]): if not calls: calls.append(1) return self.recovering_sessions_json else: return self.ready_sessions_json + self.http_client.get_session.side_effect = get_session self.http_client.get_statement.return_value = self.ready_statement_json session = self._create_session() @@ -284,16 +315,18 @@ def test_logs_gets_latest_logs(self): def test_wait_for_idle_returns_when_in_state(self): self.http_client.post_session.return_value = self.session_create_json - self.get_session_responses = [self.ready_sessions_json, - self.ready_sessions_json, - self.busy_sessions_json, - self.ready_sessions_json] + self.get_session_responses = [ + self.ready_sessions_json, + self.ready_sessions_json, + self.busy_sessions_json, + self.ready_sessions_json, + ] self.http_client.get_session.side_effect = self._next_session_response_get self.http_client.get_statement.return_value = self.ready_statement_json session = self._create_session() session.get_row_html = MagicMock() - session.get_row_html.return_value = u"""""" + session.get_row_html.return_value = """""" session.start() @@ -304,17 +337,19 @@ def test_wait_for_idle_returns_when_in_state(self): def test_wait_for_idle_prints_resource_limit_message(self): self.http_client.post_session.return_value = self.session_create_json - self.get_session_responses = [self.resource_limit_json, - self.ready_sessions_json, - self.ready_sessions_json, - self.ready_sessions_json] + self.get_session_responses = [ + self.resource_limit_json, + self.ready_sessions_json, + self.ready_sessions_json, + self.ready_sessions_json, + ] self.http_client.get_session.side_effect = self._next_session_response_get self.http_client.get_statement.return_value = self.ready_statement_json self.http_client.get_all_session_logs.return_value = self.log_json session = self._create_session() session.get_row_html = MagicMock() - session.get_row_html.return_value = u"""""" + session.get_row_html.return_value = """""" session.start() @@ -324,16 +359,18 @@ def test_wait_for_idle_prints_resource_limit_message(self): @raises(LivyUnexpectedStatusException) def test_wait_for_idle_throws_when_in_final_status(self): self.http_client.post_session.return_value = self.session_create_json - self.get_session_responses = [self.ready_sessions_json, - self.busy_sessions_json, - self.busy_sessions_json, - self.error_sessions_json] + self.get_session_responses = [ + self.ready_sessions_json, + self.busy_sessions_json, + self.busy_sessions_json, + self.error_sessions_json, + ] self.http_client.get_session.side_effect = self._next_session_response_get self.http_client.get_all_session_logs.return_value = self.log_json session = self._create_session() session.get_row_html = MagicMock() - session.get_row_html.return_value = u"""""" + session.get_row_html.return_value = """""" session.start() @@ -342,17 +379,19 @@ def test_wait_for_idle_throws_when_in_final_status(self): @raises(LivyClientTimeoutException) def test_wait_for_idle_times_out(self): self.http_client.post_session.return_value = self.session_create_json - self.get_session_responses = [self.ready_sessions_json, - self.ready_sessions_json, - self.busy_sessions_json, - self.busy_sessions_json, - self.ready_sessions_json] + self.get_session_responses = [ + self.ready_sessions_json, + self.ready_sessions_json, + self.busy_sessions_json, + self.busy_sessions_json, + self.ready_sessions_json, + ] self.http_client.get_session.side_effect = self._next_session_response_get 
self.http_client.get_statement.return_value = self.ready_statement_json session = self._create_session() session.get_row_html = MagicMock() - session.get_row_html.return_value = u"""""" + session.get_row_html.return_value = """""" session.start() @@ -395,9 +434,12 @@ def test_start_emits_start_end_session(self): session = self._create_session(kind=kind) session.start() - self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind) + self.spark_events.emit_session_creation_start_event.assert_called_once_with( + session.guid, kind + ) self.spark_events.emit_session_creation_end_event.assert_called_once_with( - session.guid, kind, session.id, session.status, True, "", "") + session.guid, kind, session.id, session.status, True, "", "" + ) def test_start_emits_start_end_failed_session_when_bad_status(self): self.http_client.post_session.side_effect = ValueError @@ -412,9 +454,12 @@ def test_start_emits_start_end_failed_session_when_bad_status(self): except ValueError: pass - self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind) + self.spark_events.emit_session_creation_start_event.assert_called_once_with( + session.guid, kind + ) self.spark_events.emit_session_creation_end_event.assert_called_once_with( - session.guid, kind, session.id, session.status, False, "ValueError", "") + session.guid, kind, session.id, session.status, False, "ValueError", "" + ) def test_start_emits_start_end_failed_session_when_wait_for_idle_throws(self): self.http_client.post_session.return_value = self.session_create_json @@ -430,9 +475,12 @@ def test_start_emits_start_end_failed_session_when_wait_for_idle_throws(self): except ValueError: pass - self.spark_events.emit_session_creation_start_event.assert_called_once_with(session.guid, kind) + self.spark_events.emit_session_creation_start_event.assert_called_once_with( + session.guid, kind + ) self.spark_events.emit_session_creation_end_event.assert_called_once_with( - session.guid, kind, session.id, session.status, False, "ValueError", "") + session.guid, kind, session.id, session.status, False, "ValueError", "" + ) def test_delete_session_emits_start_end(self): self.http_client.post_session.return_value = self.session_create_json @@ -449,9 +497,17 @@ def test_delete_session_emits_start_end(self): assert_equals(session.id, -1) self.spark_events.emit_session_deletion_start_event.assert_called_once_with( - session.guid, session.kind, end_id, end_status) + session.guid, session.kind, end_id, end_status + ) self.spark_events.emit_session_deletion_end_event.assert_called_once_with( - session.guid, session.kind, end_id, constants.DEAD_SESSION_STATUS, True, "", "") + session.guid, + session.kind, + end_id, + constants.DEAD_SESSION_STATUS, + True, + "", + "", + ) def test_delete_session_emits_start_failed_end_when_delete_throws(self): self.http_client.delete_session.side_effect = ValueError @@ -472,9 +528,11 @@ def test_delete_session_emits_start_failed_end_when_delete_throws(self): pass self.spark_events.emit_session_deletion_start_event.assert_called_once_with( - session.guid, session.kind, end_id, end_status) + session.guid, session.kind, end_id, end_status + ) self.spark_events.emit_session_deletion_end_event.assert_called_once_with( - session.guid, session.kind, end_id, end_status, False, "ValueError", "") + session.guid, session.kind, end_id, end_status, False, "ValueError", "" + ) def test_delete_session_emits_start_failed_end_when_in_bad_state(self): self.http_client.get_session.return_value 
= self.ready_sessions_json @@ -491,9 +549,17 @@ def test_delete_session_emits_start_failed_end_when_in_bad_state(self): assert_equals(0, session.ipython_display.send_error.call_count) self.spark_events.emit_session_deletion_start_event.assert_called_once_with( - session.guid, session.kind, end_id, end_status) + session.guid, session.kind, end_id, end_status + ) self.spark_events.emit_session_deletion_end_event.assert_called_once_with( - session.guid, session.kind, end_id, constants.DEAD_SESSION_STATUS, True, "", "") + session.guid, + session.kind, + end_id, + constants.DEAD_SESSION_STATUS, + True, + "", + "", + ) def test_get_empty_app_id(self): self._verify_get_app_id("null", None, 7) @@ -502,20 +568,22 @@ def test_get_missing_app_id(self): self._verify_get_app_id(None, None, 7) def test_get_normal_app_id(self): - self._verify_get_app_id("\"app_id_123\"", "app_id_123", 6) + self._verify_get_app_id('"app_id_123"', "app_id_123", 6) def test_get_empty_driver_log_url(self): self._verify_get_driver_log_url("null", None) def test_get_normal_driver_log_url(self): - self._verify_get_driver_log_url("\"http://example.com\"", "http://example.com") + self._verify_get_driver_log_url('"http://example.com"', "http://example.com") def test_missing_app_info_get_driver_log_url(self): self._verify_get_driver_log_url_json(self.ready_sessions_json, None) - + def _verify_get_app_id(self, mock_app_id, expected_app_id, expected_call_count): - mock_field = ",\"appId\":" + mock_app_id if mock_app_id is not None else "" - get_session_json = json.loads('{"id":0,"state":"idle","output":null%s,"log":""}' % mock_field) + mock_field = ',"appId":' + mock_app_id if mock_app_id is not None else "" + get_session_json = json.loads( + '{"id":0,"state":"idle","output":null%s,"log":""}' % mock_field + ) session = self._create_session_with_fixed_get_response(get_session_json) app_id = session.get_app_id() @@ -524,8 +592,14 @@ def _verify_get_app_id(self, mock_app_id, expected_app_id, expected_call_count): assert_equals(expected_call_count, self.http_client.get_session.call_count) def _verify_get_driver_log_url(self, mock_driver_log_url, expected_url): - mock_field = "\"driverLogUrl\":" + mock_driver_log_url if mock_driver_log_url is not None else "" - session_json = json.loads('{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field) + mock_field = ( + '"driverLogUrl":' + mock_driver_log_url + if mock_driver_log_url is not None + else "" + ) + session_json = json.loads( + '{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field + ) self._verify_get_driver_log_url_json(session_json, expected_url) def _verify_get_driver_log_url_json(self, get_session_json, expected_url): @@ -540,14 +614,18 @@ def test_get_empty_spark_ui_url(self): self._verify_get_spark_ui_url("null", None) def test_get_normal_spark_ui_url(self): - self._verify_get_spark_ui_url("\"http://example.com\"", "http://example.com") + self._verify_get_spark_ui_url('"http://example.com"', "http://example.com") def test_missing_app_info_get_spark_ui_url(self): self._verify_get_spark_ui_url_json(self.ready_sessions_json, None) def _verify_get_spark_ui_url(self, mock_spark_ui_url, expected_url): - mock_field = "\"sparkUiUrl\":" + mock_spark_ui_url if mock_spark_ui_url is not None else "" - session_json = json.loads('{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field) + mock_field = ( + '"sparkUiUrl":' + mock_spark_ui_url if mock_spark_ui_url is not None else "" + ) + session_json = json.loads( + 
'{"id":0,"state":"idle","output":null,"appInfo":{%s},"log":""}' % mock_field + ) self._verify_get_spark_ui_url_json(session_json, expected_url) def _verify_get_spark_ui_url_json(self, get_session_json, expected_url): @@ -565,38 +643,48 @@ def test_get_row_html(self): session1.get_spark_ui_url = MagicMock() session1.get_driver_log_url = MagicMock() session1.get_user = MagicMock() - session1.get_app_id.return_value = 'app1234' + session1.get_app_id.return_value = "app1234" session1.status = constants.IDLE_SESSION_STATUS - session1.get_spark_ui_url.return_value = 'https://microsoft.com/sparkui' - session1.get_driver_log_url.return_value = 'https://microsoft.com/driverlog' - session1.get_user.return_value = 'userTest' + session1.get_spark_ui_url.return_value = "https://microsoft.com/sparkui" + session1.get_driver_log_url.return_value = "https://microsoft.com/driverlog" + session1.get_user.return_value = "userTest" html1 = session1.get_row_html(1) - assert_equals(html1, u"""""") + assert_equals( + html1, + """""", + ) session_id2 = 3 - session2 = self._create_session(kind=constants.SESSION_KIND_PYSPARK, - session_id=session_id2) + session2 = self._create_session( + kind=constants.SESSION_KIND_PYSPARK, session_id=session_id2 + ) session2.get_app_id = MagicMock() session2.get_spark_ui_url = MagicMock() session2.get_driver_log_url = MagicMock() session2.get_user = MagicMock() - session2.get_app_id.return_value = 'app5069' + session2.get_app_id.return_value = "app5069" session2.status = constants.BUSY_SESSION_STATUS session2.get_spark_ui_url.return_value = None session2.get_driver_log_url.return_value = None - session2.get_user.return_value = 'userTest2' + session2.get_user.return_value = "userTest2" html2 = session2.get_row_html(1) - assert_equals(html2, u"""""") + assert_equals( + html2, + """""", + ) def test_link(self): - url = u"https://microsoft.com" - assert_equals(LivySession.get_html_link(u'Link', url), u"""Link""") + url = "https://microsoft.com" + assert_equals( + LivySession.get_html_link("Link", url), + """Link""", + ) url = None - assert_equals(LivySession.get_html_link(u'Link', url), u"") + assert_equals(LivySession.get_html_link("Link", url), "") def test_spark_session_available(self): self.http_client.post_session.return_value = self.session_create_json @@ -604,13 +692,15 @@ def test_spark_session_available(self): self.http_client.get_statement.return_value = self.ready_statement_json session = self._create_session() session.start() - assert_equals(session.sql_context_variable_name,"spark") + assert_equals(session.sql_context_variable_name, "spark") def test_sql_context_available(self): self.http_client.post_session.return_value = self.session_create_json self.http_client.get_session.return_value = self.ready_sessions_json - self.get_statement_responses = [self.ready_statement_failed_json, - self.ready_statement_json] + self.get_statement_responses = [ + self.ready_statement_failed_json, + self.ready_statement_json, + ] self.http_client.get_statement.side_effect = self._next_statement_response_get session = self._create_session() session.start() diff --git a/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py b/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py index 6a3d8750a..9b52571b6 100644 --- a/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py +++ b/sparkmagic/sparkmagic/tests/test_pd_data_coerce.py @@ -5,8 +5,10 @@ def test_no_coercing(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': u'random', u'temp_diff': 
u'0adsf'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": "random", "temp_diff": "0adsf"}, + ] desired_df = pd.DataFrame(records) df = pd.DataFrame(records) @@ -16,8 +18,10 @@ def test_no_coercing(): def test_date_coercing(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': u'6/1/13', u'temp_diff': u'0adsf'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": "6/1/13", "temp_diff": "0adsf"}, + ] desired_df = pd.DataFrame(records) desired_df["date"] = pd.to_datetime(desired_df["date"]) @@ -28,8 +32,10 @@ def test_date_coercing(): def test_date_coercing_none_values(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': None, u'temp_diff': u'0adsf'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": None, "temp_diff": "0adsf"}, + ] desired_df = pd.DataFrame(records) desired_df["date"] = pd.to_datetime(desired_df["date"]) @@ -40,9 +46,11 @@ def test_date_coercing_none_values(): def test_date_none_values_and_no_coercing(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': None, u'temp_diff': u'0adsf'}, - {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0adsf'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": None, "temp_diff": "0adsf"}, + {"buildingID": 1, "date": "adsf", "temp_diff": "0adsf"}, + ] desired_df = pd.DataFrame(records) df = pd.DataFrame(records) @@ -52,8 +60,10 @@ def test_date_none_values_and_no_coercing(): def test_numeric_coercing(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": "adsf", "temp_diff": "0"}, + ] desired_df = pd.DataFrame(records) desired_df["temp_diff"] = pd.to_numeric(desired_df["temp_diff"]) @@ -64,8 +74,10 @@ def test_numeric_coercing(): def test_numeric_coercing_none_values(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': u'asdf', u'temp_diff': None}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": "asdf", "temp_diff": None}, + ] desired_df = pd.DataFrame(records) desired_df["temp_diff"] = pd.to_numeric(desired_df["temp_diff"]) @@ -76,9 +88,11 @@ def test_numeric_coercing_none_values(): def test_numeric_none_values_and_no_coercing(): - records = [{u'buildingID': 0, u'date': u'6/1/13', u'temp_diff': u'12'}, - {u'buildingID': 1, u'date': u'asdf', u'temp_diff': None}, - {u'buildingID': 1, u'date': u'adsf', u'temp_diff': u'0asdf'}] + records = [ + {"buildingID": 0, "date": "6/1/13", "temp_diff": "12"}, + {"buildingID": 1, "date": "asdf", "temp_diff": None}, + {"buildingID": 1, "date": "adsf", "temp_diff": "0asdf"}, + ] desired_df = pd.DataFrame(records) df = pd.DataFrame(records) @@ -125,18 +139,18 @@ def test_df_dict_does_not_throw(): def test_overflow_coercing(): - records = [{'_c0':'12345678901'}] + records = [{"_c0": "12345678901"}] desired_df = pd.DataFrame(records) - desired_df['_c0'] = pd.to_numeric(desired_df['_c0']) + desired_df["_c0"] = pd.to_numeric(desired_df["_c0"]) df = pd.DataFrame(records) coerce_pandas_df_to_numeric_datetime(df) assert_frame_equal(desired_df, df) - + def 
test_all_null_columns(): - records = [{'_c0':'12345', 'nulla': None}, {'_c0':'12345', 'nulla': None}] + records = [{"_c0": "12345", "nulla": None}, {"_c0": "12345", "nulla": None}] desired_df = pd.DataFrame(records) - desired_df['_c0'] = pd.to_numeric(desired_df['_c0']) + desired_df["_c0"] = pd.to_numeric(desired_df["_c0"]) df = pd.DataFrame(records) coerce_pandas_df_to_numeric_datetime(df) assert_frame_equal(desired_df, df) diff --git a/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py b/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py index 61851b41b..de0d9dcc6 100644 --- a/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py +++ b/sparkmagic/sparkmagic/tests/test_reliablehttpclient.py @@ -2,7 +2,14 @@ # Distributed under the terms of the Modified BSD License. from mock import patch, PropertyMock, MagicMock -from nose.tools import raises, assert_equals, with_setup, assert_is_not_none, assert_false, assert_true +from nose.tools import ( + raises, + assert_equals, + with_setup, + assert_is_not_none, + assert_false, + assert_true, +) import requests from requests_kerberos.kerberos_ import HTTPKerberosAuth, REQUIRED, OPTIONAL from sparkmagic.auth.basic import Basic @@ -60,7 +67,7 @@ def test_compose_url(): @with_setup(_setup, _teardown) def test_get(): - with patch('requests.Session.get') as patched_get: + with patch("requests.Session.get") as patched_get: type(patched_get.return_value).status_code = 200 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -73,7 +80,7 @@ def test_get(): @raises(HttpClientException) @with_setup(_setup, _teardown) def test_get_throws(): - with patch('requests.Session.get') as patched_get: + with patch("requests.Session.get") as patched_get: type(patched_get.return_value).status_code = 500 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -88,7 +95,7 @@ def test_get_will_retry(): retry_policy.should_retry.return_value = True retry_policy.seconds_to_sleep.return_value = 0.01 - with patch('requests.Session.get') as patched_get: + with patch("requests.Session.get") as patched_get: # When we call assert_equals in this unit test, the side_effect is executed. # So, the last status_code should be repeated. sequential_values = [500, 200, 200] @@ -106,7 +113,7 @@ def test_get_will_retry(): @with_setup(_setup, _teardown) def test_post(): - with patch('requests.Session.post') as patched_post: + with patch("requests.Session.post") as patched_post: type(patched_post.return_value).status_code = 200 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -119,7 +126,7 @@ def test_post(): @raises(HttpClientException) @with_setup(_setup, _teardown) def test_post_throws(): - with patch('requests.Session.post') as patched_post: + with patch("requests.Session.post") as patched_post: type(patched_post.return_value).status_code = 500 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -134,7 +141,7 @@ def test_post_will_retry(): retry_policy.should_retry.return_value = True retry_policy.seconds_to_sleep.return_value = 0.01 - with patch('requests.Session.post') as patched_post: + with patch("requests.Session.post") as patched_post: # When we call assert_equals in this unit test, the side_effect is executed. # So, the last status_code should be repeated. 
sequential_values = [500, 200, 200] @@ -152,7 +159,7 @@ def test_post_will_retry(): @with_setup(_setup, _teardown) def test_delete(): - with patch('requests.Session.delete') as patched_delete: + with patch("requests.Session.delete") as patched_delete: type(patched_delete.return_value).status_code = 200 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -165,7 +172,7 @@ def test_delete(): @raises(HttpClientException) @with_setup(_setup, _teardown) def test_delete_throws(): - with patch('requests.Session.delete') as patched_delete: + with patch("requests.Session.delete") as patched_delete: type(patched_delete.return_value).status_code = 500 client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -180,7 +187,7 @@ def test_delete_will_retry(): retry_policy.should_retry.return_value = True retry_policy.seconds_to_sleep.return_value = 0.01 - with patch('requests.Session.delete') as patched_delete: + with patch("requests.Session.delete") as patched_delete: # When we call assert_equals in this unit test, the side_effect is executed. # So, the last status_code should be repeated. sequential_values = [500, 200, 200] @@ -203,7 +210,7 @@ def test_will_retry_error_no(): retry_policy.should_retry.return_value = False retry_policy.seconds_to_sleep.return_value = 0.01 - with patch('requests.Session.get') as patched_get: + with patch("requests.Session.get") as patched_get: patched_get.side_effect = requests.exceptions.ConnectionError() client = ReliableHttpClient(endpoint, {}, retry_policy) @@ -219,8 +226,8 @@ def test_basic_auth_check_auth(): endpoint = Endpoint("http://url.com", basic_auth) client = ReliableHttpClient(endpoint, {}, retry_policy) assert isinstance(client._auth, Basic) - assert hasattr(client._auth, 'username') - assert hasattr(client._auth, 'password') + assert hasattr(client._auth, "username") + assert hasattr(client._auth, "password") assert_equals(client._auth.username, endpoint.auth.username) assert_equals(client._auth.password, endpoint.auth.password) @@ -238,24 +245,21 @@ def test_kerberos_auth_check_auth(): client = ReliableHttpClient(endpoint, {}, retry_policy) assert_is_not_none(client._auth) assert isinstance(client._auth, HTTPKerberosAuth) - assert hasattr(client._auth, 'mutual_authentication') + assert hasattr(client._auth, "mutual_authentication") assert_equals(client._auth.mutual_authentication, REQUIRED) @with_setup(_setup, _teardown) def test_kerberos_auth_custom_configuration(): - custom_kerberos_conf = { - "mutual_authentication": OPTIONAL, - "force_preemptive": True - } - overrides = { conf.kerberos_auth_configuration.__name__: custom_kerberos_conf } + custom_kerberos_conf = {"mutual_authentication": OPTIONAL, "force_preemptive": True} + overrides = {conf.kerberos_auth_configuration.__name__: custom_kerberos_conf} conf.override_all(overrides) kerberos_auth = Kerberos() endpoint = Endpoint("http://url.com", kerberos_auth) client = ReliableHttpClient(endpoint, {}, retry_policy) assert_is_not_none(client._auth) assert isinstance(client._auth, HTTPKerberosAuth) - assert hasattr(client._auth, 'mutual_authentication') + assert hasattr(client._auth, "mutual_authentication") assert_equals(client._auth.mutual_authentication, OPTIONAL) - assert hasattr(client._auth, 'force_preemptive') + assert hasattr(client._auth, "force_preemptive") assert_equals(client._auth.force_preemptive, True) diff --git a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py index 396232990..cc62c27ca 100644 --- 
a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py +++ b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py @@ -3,7 +3,12 @@ import sparkmagic.utils.configuration as conf from sparkmagic.utils.utils import parse_argstring_or_throw, initialize_auth -from sparkmagic.utils.constants import EXPECTED_ERROR_MSG, MIMETYPE_TEXT_PLAIN, NO_AUTH, AUTH_BASIC +from sparkmagic.utils.constants import ( + EXPECTED_ERROR_MSG, + MIMETYPE_TEXT_PLAIN, + NO_AUTH, + AUTH_BASIC, +) from sparkmagic.magics.remotesparkmagics import RemoteSparkMagics from sparkmagic.livyclientlib.command import Command from sparkmagic.livyclientlib.endpoint import Endpoint @@ -51,20 +56,21 @@ def test_info_endpoint_command_parses(): magic.spark(command) - print_info_mock.assert_called_once_with(None,1234) + print_info_mock.assert_called_once_with(None, 1234) @with_setup(_setup, _teardown) def test_info_command_exception(): - print_info_mock = MagicMock(side_effect=LivyClientTimeoutException('OHHHHHOHOHOHO')) + print_info_mock = MagicMock(side_effect=LivyClientTimeoutException("OHHHHHOHOHOHO")) magic._print_local_info = print_info_mock command = "info" magic.spark(command) print_info_mock.assert_called_once_with() - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(print_info_mock.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(print_info_mock.side_effect) + ) @with_setup(_setup, _teardown) @@ -80,8 +86,12 @@ def test_add_sessions_command_parses(): magic.spark(line) args = parse_argstring_or_throw(RemoteSparkMagics.spark, line) - add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)), - False, {"kind": "pyspark"}) + add_sessions_mock.assert_called_once_with( + "name", + Endpoint("http://url.com", initialize_auth(args)), + False, + {"kind": "pyspark"}, + ) # Skip and scala - upper case add_sessions_mock = MagicMock() spark_controller.add_session = add_sessions_mock @@ -94,8 +104,12 @@ def test_add_sessions_command_parses(): magic.spark(line) args = parse_argstring_or_throw(RemoteSparkMagics.spark, line) args.auth = NO_AUTH - add_sessions_mock.assert_called_once_with("name", Endpoint("http://location:port", initialize_auth(args)), - True, {"kind": "spark"}) + add_sessions_mock.assert_called_once_with( + "name", + Endpoint("http://location:port", initialize_auth(args)), + True, + {"kind": "spark"}, + ) @with_setup(_setup, _teardown) @@ -106,21 +120,25 @@ def test_add_sessions_command_parses_kerberos(): command = "add" name = "-s name" language = "-l python" - connection_string = "-u http://url.com -t {}".format('Kerberos') + connection_string = "-u http://url.com -t {}".format("Kerberos") line = " ".join([command, name, language, connection_string]) magic.spark(line) args = parse_argstring_or_throw(RemoteSparkMagics.spark, line) auth_instance = initialize_auth(args) - - add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)), - False, {"kind": "pyspark"}) + + add_sessions_mock.assert_called_once_with( + "name", + Endpoint("http://url.com", initialize_auth(args)), + False, + {"kind": "pyspark"}, + ) assert_equals(auth_instance.url, "http://url.com") @with_setup(_setup, _teardown) def test_add_sessions_command_exception(): # Do not skip and python - add_sessions_mock = MagicMock(side_effect=BadUserDataException('hehe')) + add_sessions_mock = MagicMock(side_effect=BadUserDataException("hehe")) spark_controller.add_session = add_sessions_mock command = "add" name = "-s 
name" @@ -130,16 +148,21 @@ def test_add_sessions_command_exception(): magic.spark(line) args = parse_argstring_or_throw(RemoteSparkMagics.spark, line) - add_sessions_mock.assert_called_once_with("name", Endpoint("http://url.com", initialize_auth(args)), - False, {"kind": "pyspark"}) - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(add_sessions_mock.side_effect)) + add_sessions_mock.assert_called_once_with( + "name", + Endpoint("http://url.com", initialize_auth(args)), + False, + {"kind": "pyspark"}, + ) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(add_sessions_mock.side_effect) + ) @with_setup(_setup, _teardown) def test_add_sessions_command_extra_properties(): conf.override_all({}) - magic.spark("config", "{\"extra\": \"yes\"}") + magic.spark("config", '{"extra": "yes"}') assert conf.session_configs() == {"extra": "yes"} add_sessions_mock = MagicMock() @@ -153,8 +176,12 @@ def test_add_sessions_command_extra_properties(): magic.spark(line) args = parse_argstring_or_throw(RemoteSparkMagics.spark, line) args.auth = NO_AUTH - add_sessions_mock.assert_called_once_with("name", Endpoint("http://livyendpoint.com", initialize_auth(args)), - False, {"kind": "spark", "extra": "yes"}) + add_sessions_mock.assert_called_once_with( + "name", + Endpoint("http://livyendpoint.com", initialize_auth(args)), + False, + {"kind": "spark", "extra": "yes"}, + ) conf.override_all({}) @@ -176,13 +203,14 @@ def test_delete_sessions_command_parses(): @with_setup(_setup, _teardown) def test_delete_sessions_command_exception(): - mock_method = MagicMock(side_effect=LivyUnexpectedStatusException('FEEEEEELINGS')) + mock_method = MagicMock(side_effect=LivyUnexpectedStatusException("FEEEEEELINGS")) spark_controller.delete_session_by_name = mock_method command = "delete -s name" magic.spark(command) mock_method.assert_called_once_with("name") - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(mock_method.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(mock_method.side_effect) + ) @with_setup(_setup, _teardown) @@ -198,14 +226,17 @@ def test_cleanup_command_parses(): @with_setup(_setup, _teardown) def test_cleanup_command_exception(): - mock_method = MagicMock(side_effect=SessionManagementException('Livy did something VERY BAD')) + mock_method = MagicMock( + side_effect=SessionManagementException("Livy did something VERY BAD") + ) spark_controller.cleanup = mock_method line = "cleanup" magic.spark(line) mock_method.assert_called_once_with() - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(mock_method.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(mock_method.side_effect) + ) @with_setup(_setup, _teardown) @@ -232,7 +263,9 @@ def test_bad_command_writes_error(): magic.spark(line) - ipython_display.send_error.assert_called_once_with("Subcommand '{}' not found. {}".format(line, usage)) + ipython_display.send_error.assert_called_once_with( + "Subcommand '{}' not found. 
{}".format(line, usage) + ) @with_setup(_setup, _teardown) @@ -270,13 +303,15 @@ def test_run_cell_command_writes_to_err(): run_cell_method.assert_called_once_with(Command(cell), name) assert result is None - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG.format(result_value)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(result_value) + ) @with_setup(_setup, _teardown) def test_run_cell_command_exception(): run_cell_method = MagicMock() - run_cell_method.side_effect = HttpClientException('meh') + run_cell_method.side_effect = HttpClientException("meh") spark_controller.run_command = run_cell_method command = "-s" @@ -288,8 +323,9 @@ def test_run_cell_command_exception(): run_cell_method.assert_called_once_with(Command(cell), name) assert result is None - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(run_cell_method.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(run_cell_method.side_effect) + ) @with_setup(_setup, _teardown) @@ -307,8 +343,9 @@ def test_run_spark_command_parses(): result = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with("cell code", - None, "sample", None, None, "sessions_name", None) + magic.execute_spark.assert_called_once_with( + "cell code", None, "sample", None, None, "sessions_name", None + ) @with_setup(_setup, _teardown) @@ -323,13 +360,16 @@ def test_run_spark_command_parses_with_coerce(): method_name = "sample" coer = "--coerce" coerce_value = "True" - line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value]) + line = " ".join( + [command, name, context, context_name, meth, method_name, coer, coerce_value] + ) cell = "cell code" result = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with("cell code", - None, "sample", None, None, "sessions_name", True) + magic.execute_spark.assert_called_once_with( + "cell code", None, "sample", None, None, "sessions_name", True + ) @with_setup(_setup, _teardown) @@ -344,13 +384,16 @@ def test_run_spark_command_parses_with_coerce_false(): method_name = "sample" coer = "--coerce" coerce_value = "False" - line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value]) + line = " ".join( + [command, name, context, context_name, meth, method_name, coer, coerce_value] + ) cell = "cell code" result = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with("cell code", - None, "sample", None, None, "sessions_name", False) + magic.execute_spark.assert_called_once_with( + "cell code", None, "sample", None, None, "sessions_name", False + ) @with_setup(_setup, _teardown) @@ -365,13 +408,16 @@ def test_run_sql_command_parses_with_coerce_false(): method_name = "sample" coer = "--coerce" coerce_value = "False" - line = " ".join([command, name, context, context_name, meth, method_name, coer, coerce_value]) + line = " ".join( + [command, name, context, context_name, meth, method_name, coer, coerce_value] + ) cell = "cell code" result = magic.spark(line, cell) - magic.execute_sqlquery.assert_called_once_with("cell code", - "sample", None, None, "sessions_name", None, False, False) + magic.execute_sqlquery.assert_called_once_with( + "cell code", "sample", None, None, "sessions_name", None, False, False + ) @with_setup(_setup, _teardown) @@ -386,12 +432,16 @@ def test_run_spark_with_store_command_parses(): method_name = "sample" output = "-o" output_var = "var_name" - line = " 
".join([command, name, context, context_name, meth, method_name, output, output_var]) + line = " ".join( + [command, name, context, context_name, meth, method_name, output, output_var] + ) cell = "cell code" result = magic.spark(line, cell) - magic.execute_spark.assert_called_once_with("cell code", - "var_name", "sample", None, None, "sessions_name", None) + magic.execute_spark.assert_called_once_with( + "cell code", "var_name", "sample", None, None, "sessions_name", None + ) + @with_setup(_setup, _teardown) def test_run_spark_with_store_correct_calls(): @@ -409,19 +459,34 @@ def test_run_spark_with_store_correct_calls(): output_var = "var_name" coer = "--coerce" coerce_value = "True" - line = " ".join([command, name, context, context_name, meth, method_name, output, output_var, coer, coerce_value]) + line = " ".join( + [ + command, + name, + context, + context_name, + meth, + method_name, + output, + output_var, + coer, + coerce_value, + ] + ) cell = "cell code" result = magic.spark(line, cell) run_cell_method.assert_any_call(Command(cell), name) - run_cell_method.assert_any_call(SparkStoreCommand(output_var, samplemethod=method_name, coerce=True), name) + run_cell_method.assert_any_call( + SparkStoreCommand(output_var, samplemethod=method_name, coerce=True), name + ) @with_setup(_setup, _teardown) def test_run_spark_command_exception(): run_cell_method = MagicMock() - run_cell_method.side_effect = LivyUnexpectedStatusException('WOW') + run_cell_method.side_effect = LivyUnexpectedStatusException("WOW") spark_controller.run_command = run_cell_method command = "-s" @@ -432,19 +497,23 @@ def test_run_spark_command_exception(): method_name = "sample" output = "-o" output_var = "var_name" - line = " ".join([command, name, context, context_name, meth, method_name, output, output_var]) + line = " ".join( + [command, name, context, context_name, meth, method_name, output, output_var] + ) cell = "cell code" result = magic.spark(line, cell) run_cell_method.assert_any_call(Command(cell), name) - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(run_cell_method.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(run_cell_method.side_effect) + ) + @with_setup(_setup, _teardown) def test_run_spark_command_exception_while_storing(): run_cell_method = MagicMock() - exception = LivyUnexpectedStatusException('WOW') + exception = LivyUnexpectedStatusException("WOW") run_cell_method.side_effect = [(True, "", MIMETYPE_TEXT_PLAIN), exception] spark_controller.run_command = run_cell_method @@ -456,17 +525,21 @@ def test_run_spark_command_exception_while_storing(): method_name = "sample" output = "-o" output_var = "var_name" - line = " ".join([command, name, context, context_name, meth, method_name, output, output_var]) + line = " ".join( + [command, name, context, context_name, meth, method_name, output, output_var] + ) cell = "cell code" result = magic.spark(line, cell) run_cell_method.assert_any_call(Command(cell), name) - run_cell_method.assert_any_call(SparkStoreCommand(output_var, samplemethod=method_name), name) + run_cell_method.assert_any_call( + SparkStoreCommand(output_var, samplemethod=method_name), name + ) ipython_display.write.assert_called_once_with("") - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(exception)) - + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(exception) + ) @with_setup(_setup, _teardown) @@ -486,14 +559,16 @@ def 
test_run_sql_command_parses(): result = magic.spark(line, cell) - run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name) + run_cell_method.assert_called_once_with( + SQLQuery(cell, samplemethod=method_name), name + ) assert result is not None @with_setup(_setup, _teardown) def test_run_sql_command_exception(): run_cell_method = MagicMock() - run_cell_method.side_effect = LivyUnexpectedStatusException('WOW') + run_cell_method.side_effect = LivyUnexpectedStatusException("WOW") spark_controller.run_sqlquery = run_cell_method command = "-s" @@ -507,9 +582,12 @@ def test_run_sql_command_exception(): result = magic.spark(line, cell) - run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name) - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(run_cell_method.side_effect)) + run_cell_method.assert_called_once_with( + SQLQuery(cell, samplemethod=method_name), name + ) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(run_cell_method.side_effect) + ) @with_setup(_setup, _teardown) @@ -530,7 +608,9 @@ def test_run_sql_command_knows_how_to_be_quiet(): result = magic.spark(line, cell) - run_cell_method.assert_called_once_with(SQLQuery(cell, samplemethod=method_name), name) + run_cell_method.assert_called_once_with( + SQLQuery(cell, samplemethod=method_name), name + ) assert result is None @@ -555,7 +635,9 @@ def test_logs_subcommand(): @with_setup(_setup, _teardown) def test_logs_exception(): - get_logs_method = MagicMock(side_effect=LivyUnexpectedStatusException('How did this happen?')) + get_logs_method = MagicMock( + side_effect=LivyUnexpectedStatusException("How did this happen?") + ) result_value = "" get_logs_method.return_value = result_value spark_controller.get_logs = get_logs_method @@ -569,5 +651,6 @@ def test_logs_exception(): get_logs_method.assert_called_once_with(name) assert result is None - ipython_display.send_error.assert_called_once_with(EXPECTED_ERROR_MSG - .format(get_logs_method.side_effect)) + ipython_display.send_error.assert_called_once_with( + EXPECTED_ERROR_MSG.format(get_logs_method.side_effect) + ) diff --git a/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py b/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py index f303e369b..afae91935 100644 --- a/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py +++ b/sparkmagic/sparkmagic/tests/test_sendpandasdftosparkcommand.py @@ -6,122 +6,211 @@ from nose.tools import assert_raises, assert_equals from sparkmagic.livyclientlib.command import Command import sparkmagic.utils.constants as constants -from sparkmagic.livyclientlib.sendpandasdftosparkcommand import SendPandasDfToSparkCommand +from sparkmagic.livyclientlib.sendpandasdftosparkcommand import ( + SendPandasDfToSparkCommand, +) + def test_send_to_scala(): - input_variable_name = 'input' - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) - output_variable_name = 'output' + input_variable_name = "input" + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) + output_variable_name = "output" maxrows = 1 - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows) + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, maxrows + ) sparkcommand._scala_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_SPARK, input_variable_name, input_variable_value, 
output_variable_name) - sparkcommand._scala_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_SPARK, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._scala_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_send_to_r(): - input_variable_name = 'input' - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) - output_variable_name = 'output' + input_variable_name = "input" + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) + output_variable_name = "output" maxrows = 1 - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows) + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, maxrows + ) sparkcommand._r_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_SPARKR, input_variable_name, input_variable_value, output_variable_name) - sparkcommand._r_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_SPARKR, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._r_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_send_to_python(): - input_variable_name = 'input' - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) - output_variable_name = 'output' + input_variable_name = "input" + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) + output_variable_name = "output" maxrows = 1 - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, maxrows) + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, maxrows + ) sparkcommand._pyspark_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_PYSPARK, input_variable_name, input_variable_value, output_variable_name) - sparkcommand._pyspark_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_PYSPARK, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._pyspark_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_should_create_a_valid_scala_expression(): input_variable_name = "input" - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) output_variable_name = "output" - pandas_df_jsonized = u'''[{"A":1,"B":2}]''' - expected_scala_code = u''' + pandas_df_jsonized = """[{"A":1,"B":2}]""" + expected_scala_code = ''' val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil) - val {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name) + val {} = spark.read.json(rdd_json_array)'''.format( + pandas_df_jsonized, output_variable_name + ) + + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, 1 + ) + assert_equals( + sparkcommand._scala_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command(expected_scala_code), + ) - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, 
input_variable_value, output_variable_name, 1) - assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name), - Command(expected_scala_code)) def test_should_create_a_valid_r_expression(): input_variable_name = "input" - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) output_variable_name = "output" - pandas_df_jsonized = u'''[{"A":1,"B":2}]''' - expected_r_code = u''' + pandas_df_jsonized = """[{"A":1,"B":2}]""" + expected_r_code = """ fileConn<-file("temporary_pandas_df_sparkmagics.txt") writeLines('{}', fileConn) close(fileConn) {} <- read.json("temporary_pandas_df_sparkmagics.txt") {}.persist() - file.remove("temporary_pandas_df_sparkmagics.txt")'''.format(pandas_df_jsonized, output_variable_name, output_variable_name) + file.remove("temporary_pandas_df_sparkmagics.txt")""".format( + pandas_df_jsonized, output_variable_name, output_variable_name + ) + + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, 1 + ) + assert_equals( + sparkcommand._r_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command(expected_r_code), + ) - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1) - assert_equals(sparkcommand._r_command(input_variable_name, input_variable_value, output_variable_name), - Command(expected_r_code)) def test_should_create_a_valid_python3_expression(): input_variable_name = "input" - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) output_variable_name = "output" - pandas_df_jsonized = u'''[{"A":1,"B":2}]''' + pandas_df_jsonized = """[{"A":1,"B":2}]""" expected_python3_code = SendPandasDfToSparkCommand._python_decode - expected_python3_code += u''' + expected_python3_code += """ json_array = json_loads_byteified('{}') rdd_json_array = spark.sparkContext.parallelize(json_array) - {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name) + {} = spark.read.json(rdd_json_array)""".format( + pandas_df_jsonized, output_variable_name + ) + + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, 1 + ) + assert_equals( + sparkcommand._pyspark_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command(expected_python3_code), + ) - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1) - assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name), - Command(expected_python3_code)) def test_should_create_a_valid_python2_expression(): input_variable_name = "input" - input_variable_value = pd.DataFrame({'A': [1], 'B' : [2]}) + input_variable_value = pd.DataFrame({"A": [1], "B": [2]}) output_variable_name = "output" - pandas_df_jsonized = u'''[{"A":1,"B":2}]''' + pandas_df_jsonized = """[{"A":1,"B":2}]""" expected_python2_code = SendPandasDfToSparkCommand._python_decode - expected_python2_code += u''' + expected_python2_code += """ json_array = json_loads_byteified('{}') rdd_json_array = spark.sparkContext.parallelize(json_array) - {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name) + {} = spark.read.json(rdd_json_array)""".format( + pandas_df_jsonized, output_variable_name + ) + + sparkcommand = 
SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, 1 + ) + assert_equals( + sparkcommand._pyspark_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command(expected_python2_code), + ) - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1) - assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name), - Command(expected_python2_code)) def test_should_properly_limit_pandas_dataframe(): input_variable_name = "input" max_rows = 1 - input_variable_value = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B' : [5, 6, 7, 8, 9]}) + input_variable_value = pd.DataFrame({"A": [0, 1, 2, 3, 4], "B": [5, 6, 7, 8, 9]}) output_variable_name = "output" - pandas_df_jsonized = u'''[{"A":0,"B":5}]''' #notice we expect json to have dropped all but one row - expected_scala_code = u''' + pandas_df_jsonized = ( + """[{"A":0,"B":5}]""" # notice we expect json to have dropped all but one row + ) + expected_scala_code = ''' val rdd_json_array = spark.sparkContext.makeRDD("""{}""" :: Nil) - val {} = spark.read.json(rdd_json_array)'''.format(pandas_df_jsonized, output_variable_name) + val {} = spark.read.json(rdd_json_array)'''.format( + pandas_df_jsonized, output_variable_name + ) + + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, max_rows + ) + assert_equals( + sparkcommand._scala_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command(expected_scala_code), + ) - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, max_rows) - assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name), - Command(expected_scala_code)) def test_should_raise_when_input_is_not_pandas_df(): input_variable_name = "input" input_variable_value = "not a pandas dataframe" output_variable_name = "output" - sparkcommand = SendPandasDfToSparkCommand(input_variable_name, input_variable_value, output_variable_name, 1) - assert_raises(BadUserDataException, sparkcommand.to_command, "spark", input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendPandasDfToSparkCommand( + input_variable_name, input_variable_value, output_variable_name, 1 + ) + assert_raises( + BadUserDataException, + sparkcommand.to_command, + "spark", + input_variable_name, + input_variable_value, + output_variable_name, + ) diff --git a/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py b/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py index 3d36ceb0c..a0e5c9ef0 100644 --- a/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py +++ b/sparkmagic/sparkmagic/tests/test_sendstringtosparkcommand.py @@ -6,67 +6,140 @@ from sparkmagic.livyclientlib.command import Command import sparkmagic.utils.constants as constants + def test_send_to_scala(): input_variable_name = "input" input_variable_value = "value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) sparkcommand._scala_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_SPARK, input_variable_name, input_variable_value, output_variable_name) - 
sparkcommand._scala_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_SPARK, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._scala_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_send_to_r(): input_variable_name = "input" input_variable_value = "value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) sparkcommand._r_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_SPARKR, input_variable_name, input_variable_value, output_variable_name) - sparkcommand._r_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_SPARKR, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._r_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_send_to_pyspark(): input_variable_name = "input" input_variable_value = "value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) sparkcommand._pyspark_command = MagicMock(return_value=MagicMock()) - sparkcommand.to_command(constants.SESSION_KIND_PYSPARK, input_variable_name, input_variable_value, output_variable_name) - sparkcommand._pyspark_command.assert_called_with(input_variable_name, input_variable_value, output_variable_name) + sparkcommand.to_command( + constants.SESSION_KIND_PYSPARK, + input_variable_name, + input_variable_value, + output_variable_name, + ) + sparkcommand._pyspark_command.assert_called_with( + input_variable_name, input_variable_value, output_variable_name + ) + def test_to_command_invalid(): input_variable_name = "input" input_variable_value = 42 output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) - assert_raises(BadUserDataException, sparkcommand.to_command, "invalid", input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) + assert_raises( + BadUserDataException, + sparkcommand.to_command, + "invalid", + input_variable_name, + input_variable_value, + output_variable_name, + ) + def test_should_raise_when_input_aint_a_string(): input_variable_name = "input" input_variable_value = 42 output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) - assert_raises(BadUserDataException, sparkcommand.to_command, "spark", input_variable_name, input_variable_value, output_variable_name) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) + assert_raises( + BadUserDataException, + sparkcommand.to_command, + "spark", + input_variable_name, + input_variable_value, + output_variable_name, + ) + def test_should_create_a_valid_scala_expression(): input_variable_name = "input" input_variable_value = 
"value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) - assert_equals(sparkcommand._scala_command(input_variable_name, input_variable_value, output_variable_name), - Command(u'var {} = """{}"""'.format(output_variable_name, input_variable_value))) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) + assert_equals( + sparkcommand._scala_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command('var {} = """{}"""'.format(output_variable_name, input_variable_value)), + ) + def test_should_create_a_valid_python_expression(): input_variable_name = "input" input_variable_value = "value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) - assert_equals(sparkcommand._pyspark_command(input_variable_name, input_variable_value, output_variable_name), - Command(u'{} = {}'.format(output_variable_name, repr(input_variable_value)))) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) + assert_equals( + sparkcommand._pyspark_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command("{} = {}".format(output_variable_name, repr(input_variable_value))), + ) + def test_should_create_a_valid_r_expression(): input_variable_name = "input" input_variable_value = "value" output_variable_name = "output" - sparkcommand = SendStringToSparkCommand(input_variable_name, input_variable_value, output_variable_name) - assert_equals(sparkcommand._r_command(input_variable_name, input_variable_value, output_variable_name), - Command(u'''assign("{}","{}")'''.format(output_variable_name, input_variable_value))) + sparkcommand = SendStringToSparkCommand( + input_variable_name, input_variable_value, output_variable_name + ) + assert_equals( + sparkcommand._r_command( + input_variable_name, input_variable_value, output_variable_name + ), + Command( + """assign("{}","{}")""".format(output_variable_name, input_variable_value) + ), + ) diff --git a/sparkmagic/sparkmagic/tests/test_sessionmanager.py b/sparkmagic/sparkmagic/tests/test_sessionmanager.py index 1a010ae41..eb666c2d6 100644 --- a/sparkmagic/sparkmagic/tests/test_sessionmanager.py +++ b/sparkmagic/sparkmagic/tests/test_sessionmanager.py @@ -111,7 +111,9 @@ def test_cleanup_all_sessions_on_exit(): client0.delete.assert_called_once_with() client1.delete.assert_called_once_with() - manager.ipython_display.writeln.assert_called_once_with(u"Cleaning up livy sessions on exit is enabled") + manager.ipython_display.writeln.assert_called_once_with( + "Cleaning up livy sessions on exit is enabled" + ) def test_cleanup_all_sessions_on_exit_fails(): @@ -121,7 +123,7 @@ def test_cleanup_all_sessions_on_exit_fails(): conf.override(conf.cleanup_all_sessions_on_exit.__name__, True) client0 = MagicMock() client1 = MagicMock() - client0.delete.side_effect = Exception('Mocked exception for client1.delete') + client0.delete.side_effect = Exception("Mocked exception for client1.delete") manager = get_session_manager() manager.add_session("name0", client0) manager.add_session("name1", client1) @@ -150,7 +152,7 @@ def test_get_session_name_by_id_endpoint(): name = manager.get_session_name_by_id_endpoint(id_to_search, endpoint_to_search) assert_equals(None, name) - + session = MagicMock() type(session).id = 
PropertyMock(return_value=int(id_to_search)) session.endpoint = endpoint_to_search diff --git a/sparkmagic/sparkmagic/tests/test_sparkcontroller.py b/sparkmagic/sparkmagic/tests/test_sparkcontroller.py index fa3da4263..f877db103 100644 --- a/sparkmagic/sparkmagic/tests/test_sparkcontroller.py +++ b/sparkmagic/sparkmagic/tests/test_sparkcontroller.py @@ -3,7 +3,10 @@ from mock import MagicMock from nose.tools import with_setup, assert_equals, raises from sparkmagic.livyclientlib.endpoint import Endpoint -from sparkmagic.livyclientlib.exceptions import SessionManagementException, HttpClientException +from sparkmagic.livyclientlib.exceptions import ( + SessionManagementException, + HttpClientException, +) from sparkmagic.livyclientlib.sparkcontroller import SparkController client_manager = None @@ -34,9 +37,11 @@ def _setup(): controller.session_manager = client_manager controller.spark_events = spark_events + def _teardown(): pass + @with_setup(_setup, _teardown) def test_add_session(): name = "name" @@ -49,7 +54,9 @@ def test_add_session(): controller.add_session(name, endpoint, False, properties) - controller._livy_session.assert_called_once_with(controller._http_client.return_value, properties, ipython_display) + controller._livy_session.assert_called_once_with( + controller._http_client.return_value, properties, ipython_display + ) controller.session_manager.add_session.assert_called_once_with(name, session) session.start.assert_called_once() @@ -129,10 +136,12 @@ def test_get_client_keys(): @with_setup(_setup, _teardown) def test_get_all_sessions(): http_client = MagicMock() - http_client.get_sessions.return_value = json.loads('{"from":0,"total":3,"sessions":[{"id":0,"state":"idle","kind":' - '"spark","log":[""]}, {"id":1,"state":"busy","kind":"spark","log"' - ':[""]},{"id":2,"state":"busy","kind":"sql","log"' - ':[""]}]}') + http_client.get_sessions.return_value = json.loads( + '{"from":0,"total":3,"sessions":[{"id":0,"state":"idle","kind":' + '"spark","log":[""]}, {"id":1,"state":"busy","kind":"spark","log"' + ':[""]},{"id":2,"state":"busy","kind":"sql","log"' + ':[""]}]}' + ) controller._http_client = MagicMock(return_value=http_client) controller._livy_session = MagicMock() @@ -156,21 +165,27 @@ def test_cleanup_endpoint(): @with_setup(_setup, _teardown) def test_delete_session_by_id_existent_non_managed(): http_client = MagicMock() - http_client.get_session.return_value = json.loads('{"id":0,"state":"starting","kind":"spark","log":[]}') + http_client.get_session.return_value = json.loads( + '{"id":0,"state":"starting","kind":"spark","log":[]}' + ) controller._http_client = MagicMock(return_value=http_client) session = MagicMock() controller._livy_session = MagicMock(return_value=session) controller.delete_session_by_id("conn_str", 0) - controller._livy_session.assert_called_once_with(http_client, {"kind": "spark"}, ipython_display, 0) + controller._livy_session.assert_called_once_with( + http_client, {"kind": "spark"}, ipython_display, 0 + ) session.delete.assert_called_once_with() @with_setup(_setup, _teardown) def test_delete_session_by_id_existent_managed(): name = "name" - controller.session_manager.get_session_name_by_id_endpoint = MagicMock(return_value=name) + controller.session_manager.get_session_name_by_id_endpoint = MagicMock( + return_value=name + ) controller.session_manager.get_sessions_list = MagicMock(return_value=[name]) controller.delete_session_by_name = MagicMock() @@ -190,6 +205,7 @@ def test_delete_session_by_id_non_existent(): 
controller.delete_session_by_id("conn_str", 0) + @with_setup(_setup, _teardown) def test_get_app_id(): chosen_client = MagicMock() @@ -199,6 +215,7 @@ def test_get_app_id(): assert_equals(result, chosen_client.get_app_id.return_value) chosen_client.get_app_id.assert_called_with() + @with_setup(_setup, _teardown) def test_get_driver_log(): chosen_client = MagicMock() @@ -208,6 +225,7 @@ def test_get_driver_log(): assert_equals(result, chosen_client.get_driver_log_url.return_value) chosen_client.get_driver_log_url.assert_called_with() + @with_setup(_setup, _teardown) def test_get_logs(): chosen_client = MagicMock() @@ -217,19 +235,24 @@ def test_get_logs(): assert_equals(result, chosen_client.get_logs.return_value) chosen_client.get_logs.assert_called_with() + @with_setup(_setup, _teardown) @raises(SessionManagementException) def test_get_logs_error(): chosen_client = MagicMock() - controller.get_session_by_name_or_default = MagicMock(side_effect=SessionManagementException('THERE WAS A SPOOKY GHOST')) + controller.get_session_by_name_or_default = MagicMock( + side_effect=SessionManagementException("THERE WAS A SPOOKY GHOST") + ) result = controller.get_logs() + @with_setup(_setup, _teardown) def test_get_session_id_for_client(): assert controller.get_session_id_for_client("name") is not None client_manager.get_session_id_for_client.assert_called_once_with("name") + @with_setup(_setup, _teardown) def test_get_spark_ui_url(): chosen_client = MagicMock() @@ -239,6 +262,7 @@ def test_get_spark_ui_url(): assert_equals(result, chosen_client.get_spark_ui_url.return_value) chosen_client.get_spark_ui_url.assert_called_with() + @with_setup(_setup, _teardown) def test_add_session_throws_when_session_start_fails(): name = "name" @@ -259,18 +283,22 @@ def test_add_session_throws_when_session_start_fails(): session.start.assert_called_once() controller.session_manager.add_session.assert_not_called + @with_setup(_setup, _teardown) def test_add_session_cleanup_when_timeouts_and_session_posted_to_livy(): pass + @with_setup(_setup, _teardown) def test_add_session_cleanup_when_timeouts_and_session_posted_to_livy(): _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy=True) + @with_setup(_setup, _teardown) def test_add_session_cleanup_when_timeouts_and_session_not_posted_to_livy(): _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy=False) + def _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy): name = "name" properties = {"kind": "spark"} @@ -300,6 +328,7 @@ def _do_test_add_session_cleanup_when_timeouts(is_session_posted_to_livy): else: session.delete.assert_not_called() + @with_setup(_setup, _teardown) def test_add_session_cleanup_when_session_delete_throws(): name = "name" diff --git a/sparkmagic/sparkmagic/tests/test_sparkevents.py b/sparkmagic/sparkmagic/tests/test_sparkevents.py index 340c378b4..116d2b7de 100644 --- a/sparkmagic/sparkmagic/tests/test_sparkevents.py +++ b/sparkmagic/sparkmagic/tests/test_sparkevents.py @@ -27,9 +27,11 @@ def _teardown(): @with_setup(_setup, _teardown) def test_emit_library_loaded_event(): event_name = constants.LIBRARY_LOADED_EVENT - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + ] spark_events.emit_library_loaded_event() @@ -42,13 +44,15 @@ def test_emit_cluster_change_event(): status_code = 200 event_name = constants.CLUSTER_CHANGE_EVENT - kwargs_list = 
[(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.CLUSTER_DNS_NAME, guid1), - (constants.STATUS_CODE, status_code), - (constants.SUCCESS, True), - (constants.ERROR_MESSAGE, None)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.CLUSTER_DNS_NAME, guid1), + (constants.STATUS_CODE, status_code), + (constants.SUCCESS, True), + (constants.ERROR_MESSAGE, None), + ] spark_events.emit_cluster_change_event(guid1, status_code, True, None) @@ -60,11 +64,13 @@ def test_emit_cluster_change_event(): def test_emit_session_creation_start_event(): language = constants.SESSION_KIND_SPARK event_name = constants.SESSION_CREATION_START_EVENT - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + ] spark_events.emit_session_creation_start_event(guid1, language) @@ -79,18 +85,22 @@ def test_emit_session_creation_end_event(): event_name = constants.SESSION_CREATION_END_EVENT status = constants.BUSY_SESSION_STATUS session_id = 0 - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.STATUS, status), - (constants.SUCCESS, True), - (constants.EXCEPTION_TYPE, ""), - (constants.EXCEPTION_MESSAGE, "")] - - spark_events.emit_session_creation_end_event(guid1, language, session_id, status, True, "", "") + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.STATUS, status), + (constants.SUCCESS, True), + (constants.EXCEPTION_TYPE, ""), + (constants.EXCEPTION_MESSAGE, ""), + ] + + spark_events.emit_session_creation_end_event( + guid1, language, session_id, status, True, "", "" + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -103,13 +113,15 @@ def test_emit_session_deletion_start_event(): event_name = constants.SESSION_DELETION_START_EVENT status = constants.BUSY_SESSION_STATUS session_id = 0 - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.STATUS, status)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.STATUS, status), + ] spark_events.emit_session_deletion_start_event(guid1, language, session_id, status) @@ -124,18 +136,22 @@ def test_emit_session_deletion_end_event(): event_name = constants.SESSION_DELETION_END_EVENT status = constants.BUSY_SESSION_STATUS session_id = 0 - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.STATUS, status), - 
(constants.SUCCESS, True), - (constants.EXCEPTION_TYPE, ""), - (constants.EXCEPTION_MESSAGE, "")] - - spark_events.emit_session_deletion_end_event(guid1, language, session_id, status, True, "", "") + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.STATUS, status), + (constants.SUCCESS, True), + (constants.EXCEPTION_TYPE, ""), + (constants.EXCEPTION_MESSAGE, ""), + ] + + spark_events.emit_session_deletion_end_event( + guid1, language, session_id, status, True, "", "" + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -148,15 +164,19 @@ def test_emit_statement_execution_start_event(): session_id = 7 event_name = constants.STATEMENT_EXECUTION_START_EVENT - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.STATEMENT_GUID, guid2)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.STATEMENT_GUID, guid2), + ] - spark_events.emit_statement_execution_start_event(guid1, language, session_id, guid2) + spark_events.emit_statement_execution_start_event( + guid1, language, session_id, guid2 + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -170,23 +190,33 @@ def test_emit_statement_execution_end_event(): statement_id = 400 event_name = constants.STATEMENT_EXECUTION_END_EVENT success = True - exception_type = '' - exception_message = 'foo' - - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.STATEMENT_GUID, guid2), - (constants.STATEMENT_ID, statement_id), - (constants.SUCCESS, success), - (constants.EXCEPTION_TYPE, exception_type), - (constants.EXCEPTION_MESSAGE, exception_message)] - - spark_events.emit_statement_execution_end_event(guid1, language, session_id, guid2, statement_id, success, - exception_type, exception_message) + exception_type = "" + exception_message = "foo" + + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.STATEMENT_GUID, guid2), + (constants.STATEMENT_ID, statement_id), + (constants.SUCCESS, success), + (constants.EXCEPTION_TYPE, exception_type), + (constants.EXCEPTION_MESSAGE, exception_message), + ] + + spark_events.emit_statement_execution_end_event( + guid1, + language, + session_id, + guid2, + statement_id, + success, + exception_type, + exception_message, + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -198,23 +228,26 @@ def test_emit_sql_execution_start_event(): event_name = constants.SQL_EXECUTION_START_EVENT session_id = 22 language = constants.SESSION_KIND_SPARK - samplemethod = 'sample' + samplemethod = "sample" maxrows = 12 samplefraction = 0.5 - kwargs_list = 
[(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.SQL_GUID, guid2), - (constants.SAMPLE_METHOD, samplemethod), - (constants.MAX_ROWS, maxrows), - (constants.SAMPLE_FRACTION, samplefraction)] - - spark_events.emit_sql_execution_start_event(guid1, language, session_id, guid2, samplemethod, - maxrows, samplefraction) + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.SQL_GUID, guid2), + (constants.SAMPLE_METHOD, samplemethod), + (constants.MAX_ROWS, maxrows), + (constants.SAMPLE_FRACTION, samplefraction), + ] + + spark_events.emit_sql_execution_start_event( + guid1, language, session_id, guid2, samplemethod, maxrows, samplefraction + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -227,23 +260,33 @@ def test_emit_sql_execution_end_event(): session_id = 17 language = constants.SESSION_KIND_SPARK success = False - exception_type = 'ValueError' - exception_message = 'You screwed up' - - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.SESSION_GUID, guid1), - (constants.LIVY_KIND, language), - (constants.SESSION_ID, session_id), - (constants.SQL_GUID, guid2), - (constants.STATEMENT_GUID, guid3), - (constants.SUCCESS, success), - (constants.EXCEPTION_TYPE, exception_type), - (constants.EXCEPTION_MESSAGE, exception_message)] - - spark_events.emit_sql_execution_end_event(guid1, language, session_id, guid2, guid3, - success, exception_type, exception_message) + exception_type = "ValueError" + exception_message = "You screwed up" + + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.SESSION_GUID, guid1), + (constants.LIVY_KIND, language), + (constants.SESSION_ID, session_id), + (constants.SQL_GUID, guid2), + (constants.STATEMENT_GUID, guid3), + (constants.SUCCESS, success), + (constants.EXCEPTION_TYPE, exception_type), + (constants.EXCEPTION_MESSAGE, exception_message), + ] + + spark_events.emit_sql_execution_end_event( + guid1, + language, + session_id, + guid2, + guid3, + success, + exception_type, + exception_message, + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -253,15 +296,17 @@ def test_emit_sql_execution_end_event(): @with_setup(_setup, _teardown) def test_emit_magic_execution_start_event(): event_name = constants.MAGIC_EXECUTION_START_EVENT - magic_name = 'sql' + magic_name = "sql" language = constants.SESSION_KIND_SPARKR - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.MAGIC_NAME, magic_name), - (constants.LIVY_KIND, language), - (constants.MAGIC_GUID, guid1)] + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.MAGIC_NAME, magic_name), + (constants.LIVY_KIND, language), + (constants.MAGIC_GUID, guid1), + ] spark_events.emit_magic_execution_start_event(magic_name, language, guid1) @@ -273,24 +318,27 @@ def test_emit_magic_execution_start_event(): @with_setup(_setup, _teardown) def test_emit_magic_execution_end_event(): event_name = 
constants.MAGIC_EXECUTION_END_EVENT - magic_name = 'sql' + magic_name = "sql" language = constants.SESSION_KIND_SPARKR success = True - exception_type = '' - exception_message = '' - - kwargs_list = [(INSTANCE_ID, get_instance_id()), - (EVENT_NAME, event_name), - (TIMESTAMP, time_stamp), - (constants.MAGIC_NAME, magic_name), - (constants.LIVY_KIND, language), - (constants.MAGIC_GUID, guid1), - (constants.SUCCESS, success), - (constants.EXCEPTION_TYPE, exception_type), - (constants.EXCEPTION_MESSAGE, exception_message)] - - spark_events.emit_magic_execution_end_event(magic_name, language, guid1, success, - exception_type, exception_message) + exception_type = "" + exception_message = "" + + kwargs_list = [ + (INSTANCE_ID, get_instance_id()), + (EVENT_NAME, event_name), + (TIMESTAMP, time_stamp), + (constants.MAGIC_NAME, magic_name), + (constants.LIVY_KIND, language), + (constants.MAGIC_GUID, guid1), + (constants.SUCCESS, success), + (constants.EXCEPTION_TYPE, exception_type), + (constants.EXCEPTION_MESSAGE, exception_message), + ] + + spark_events.emit_magic_execution_end_event( + magic_name, language, guid1, success, exception_type, exception_message + ) spark_events._verify_language_ok.assert_called_once_with(language) spark_events.get_utc_date_time.assert_called_with() @@ -304,4 +352,4 @@ def test_magic_verify_language_ok(): @raises(AssertionError) def test_magic_verify_language_ok_error(): - SparkEvents()._verify_language_ok('NYARGLEBARGLE') + SparkEvents()._verify_language_ok("NYARGLEBARGLE") diff --git a/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py b/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py index f901c8e48..6d383b642 100644 --- a/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py +++ b/sparkmagic/sparkmagic/tests/test_sparkkernelbase.py @@ -16,8 +16,9 @@ class TestSparkKernel(SparkKernelBase): def __init__(self): kwargs = {"testing": True} - super(TestSparkKernel, self).__init__(None, None, None, None, None, LANG_PYTHON, user_code_parser, - **kwargs) + super(TestSparkKernel, self).__init__( + None, None, None, None, None, LANG_PYTHON, user_code_parser, **kwargs + ) def _setup(): @@ -25,8 +26,9 @@ def _setup(): kernel = TestSparkKernel() - kernel._execute_cell_for_user = execute_cell_mock = MagicMock(return_value={'test': 'ing', 'a': 'b', - 'status': 'ok'}) + kernel._execute_cell_for_user = execute_cell_mock = MagicMock( + return_value={"test": "ing", "a": "b", "status": "ok"} + ) kernel._do_shutdown_ipykernel = do_shutdown_mock = MagicMock() kernel.ipython_display = ipython_display = MagicMock() @@ -78,14 +80,16 @@ def test_execute_alerts_user_if_an_unexpected_error_happens(): @with_setup(_setup, _teardown) def test_execute_throws_if_fatal_error_happens_for_execution(): # Verify that the kernel sends the error from Python execution's context to the user - fatal_error = u"Error." - message = "{}\nException details:\n\t\"{}\"".format(fatal_error, fatal_error) + fatal_error = "Error." 
+ message = '{}\nException details:\n\t"{}"'.format(fatal_error, fatal_error) reply_content = dict() - reply_content[u"status"] = u"error" - reply_content[u"evalue"] = fatal_error + reply_content["status"] = "error" + reply_content["evalue"] = fatal_error execute_cell_mock.return_value = reply_content - ret = kernel._execute_cell(code, False, shutdown_if_error=True, log_if_error=fatal_error) + ret = kernel._execute_cell( + code, False, shutdown_if_error=True, log_if_error=fatal_error + ) assert ret is execute_cell_mock.return_value assert kernel._fatal_error == message @@ -118,55 +122,80 @@ def test_shutdown_cleans_up(): def test_register_auto_viz(): kernel._register_auto_viz() - assert call("from autovizwidget.widget.utils import display_dataframe\nip = get_ipython()\nip.display_formatter" - ".ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)", - True, False, None, False) in execute_cell_mock.mock_calls + assert ( + call( + "from autovizwidget.widget.utils import display_dataframe\nip = get_ipython()\nip.display_formatter" + ".ipython_display_formatter.for_type_by_name('pandas.core.frame', 'DataFrame', display_dataframe)", + True, + False, + None, + False, + ) + in execute_cell_mock.mock_calls + ) @with_setup(_setup, _teardown) def test_change_language(): kernel._change_language() - assert call("%%_do_not_call_change_language -l {}\n ".format(LANG_PYTHON), - True, False, None, False) in execute_cell_mock.mock_calls + assert ( + call( + "%%_do_not_call_change_language -l {}\n ".format(LANG_PYTHON), + True, + False, + None, + False, + ) + in execute_cell_mock.mock_calls + ) @with_setup(_setup, _teardown) def test_load_magics(): kernel._load_magics_extension() - assert call("%load_ext sparkmagic.kernels", True, False, None, False) in execute_cell_mock.mock_calls + assert ( + call("%load_ext sparkmagic.kernels", True, False, None, False) + in execute_cell_mock.mock_calls + ) @with_setup(_setup, _teardown) def test_delete_session(): kernel._delete_session() - assert call("%%_do_not_call_delete_session\n ", True, False) in execute_cell_mock.mock_calls + assert ( + call("%%_do_not_call_delete_session\n ", True, False) + in execute_cell_mock.mock_calls + ) + -@patch.object(ipykernel.ipkernel.IPythonKernel, 'do_execute') +@patch.object(ipykernel.ipkernel.IPythonKernel, "do_execute") @with_setup(_teardown) def test_execute_cell_for_user_ipykernel5(mock_ipy_execute): import sys + if sys.version_info.major == 2: from unittest import SkipTest + raise SkipTest("Python 3 only") else: import asyncio mock_ipy_execute_result = asyncio.Future() - mock_ipy_execute_result.set_result({'status': 'OK'}) + mock_ipy_execute_result.set_result({"status": "OK"}) mock_ipy_execute.return_value = mock_ipy_execute_result - actual_result = TestSparkKernel()._execute_cell_for_user(code='Foo', silent=True) + actual_result = TestSparkKernel()._execute_cell_for_user(code="Foo", silent=True) - assert {'status': 'OK'} == actual_result + assert {"status": "OK"} == actual_result -@patch.object(ipykernel.ipkernel.IPythonKernel, 'do_execute') +@patch.object(ipykernel.ipkernel.IPythonKernel, "do_execute") @with_setup(_teardown) def test_execute_cell_for_user_ipykernel4(mock_ipy_execute): - mock_ipy_execute.return_value = {'status': 'OK'} + mock_ipy_execute.return_value = {"status": "OK"} - actual_result = TestSparkKernel()._execute_cell_for_user(code='Foo', silent=True) + actual_result = TestSparkKernel()._execute_cell_for_user(code="Foo", silent=True) - assert {'status': 'OK'} == 
actual_result + assert {"status": "OK"} == actual_result diff --git a/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py b/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py index 19b5999fd..5084a4f88 100644 --- a/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py +++ b/sparkmagic/sparkmagic/tests/test_sparkmagicsbase.py @@ -4,13 +4,25 @@ from nose.tools import with_setup, assert_equals, assert_raises, raises from sparkmagic.utils.configuration import get_livy_kind -from sparkmagic.utils.constants import LANGS_SUPPORTED, SESSION_KIND_PYSPARK, SESSION_KIND_SPARK, \ - IDLE_SESSION_STATUS, BUSY_SESSION_STATUS, MIMETYPE_TEXT_PLAIN, EXPECTED_ERROR_MSG +from sparkmagic.utils.constants import ( + LANGS_SUPPORTED, + SESSION_KIND_PYSPARK, + SESSION_KIND_SPARK, + IDLE_SESSION_STATUS, + BUSY_SESSION_STATUS, + MIMETYPE_TEXT_PLAIN, + EXPECTED_ERROR_MSG, +) from sparkmagic.magics.sparkmagicsbase import SparkMagicBase -from sparkmagic.livyclientlib.exceptions import DataFrameParseException, BadUserDataException, SparkStatementException +from sparkmagic.livyclientlib.exceptions import ( + DataFrameParseException, + BadUserDataException, + SparkStatementException, +) from sparkmagic.livyclientlib.sqlquery import SQLQuery from sparkmagic.livyclientlib.sparkstorecommand import SparkStoreCommand + def _setup(): global magic, session, shell, ipython_display shell = MagicMock() @@ -22,6 +34,7 @@ def _setup(): magic.ipython_display = MagicMock() conf.override_all({}) + def _teardown(): pass @@ -119,22 +132,24 @@ def test_print_endpoint_info(): current_session_id = 1 session1 = MagicMock() session1.id = 1 - session1.get_row_html.return_value = u"""""" + session1.get_row_html.return_value = """""" session2 = MagicMock() session2.id = 3 - session2.get_row_html.return_value = u"""""" + session2.get_row_html.return_value = """""" magic._print_endpoint_info([session2, session1], current_session_id) - magic.ipython_display.html.assert_called_once_with(u"""
<table>
+    magic.ipython_display.html.assert_called_once_with(
+        """<table>
 <tr><th>ID</th><th>YARN Application ID</th><th>Kind</th><th>State</th><th>Spark UI</th><th>Driver log</th><th>User</th><th>Current session?</th></tr>\
 <tr>row1</tr>\
 <tr>row2</tr>\
-</table>""")
+</table>"""
+    )
""") +""" + ) @with_setup(_setup, _teardown) def test_print_empty_endpoint_info(): current_session_id = None magic._print_endpoint_info([], current_session_id) - magic.ipython_display.html.assert_called_once_with(u'No active sessions.') + magic.ipython_display.html.assert_called_once_with("No active sessions.") @with_setup(_setup, _teardown) @@ -146,7 +161,10 @@ def test_send_to_spark_should_raise_when_variable_value_is_none(): max_rows = 25000 magic.shell.user_ns[input_variable_name] = None - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) + magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) + @with_setup(_setup, _teardown) @raises(BadUserDataException) @@ -158,7 +176,10 @@ def test_send_to_spark_should_raise_when_type_is_incorrect(): max_rows = 25000 magic.shell.user_ns[input_variable_name] = input_variable_value - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) + magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) + @with_setup(_setup, _teardown) def test_send_to_spark_should_print_error_when_str_command_failed(): @@ -169,13 +190,20 @@ def test_send_to_spark_should_print_error_when_str_command_failed(): output_value = "error" max_rows = 25000 magic.shell.user_ns[input_variable_name] = input_variable_value - magic.spark_controller.run_command.return_value = (False, output_value, "text/plain") + magic.spark_controller.run_command.return_value = ( + False, + output_value, + "text/plain", + ) - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) + magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) magic.ipython_display.send_error.assert_called_once_with(output_value) assert not magic.ipython_display.write.called + @with_setup(_setup, _teardown) def test_send_to_spark_should_print_error_when_df_command_failed(): input_variable_name = "x_in" @@ -185,13 +213,20 @@ def test_send_to_spark_should_print_error_when_df_command_failed(): output_value = "error" max_rows = 25000 magic.shell.user_ns[input_variable_name] = input_variable_value - magic.spark_controller.run_command.return_value = (False, output_value, "text/plain") + magic.spark_controller.run_command.return_value = ( + False, + output_value, + "text/plain", + ) - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) + magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) magic.ipython_display.send_error.assert_called_once_with(output_value) assert not magic.ipython_display.write.called + @with_setup(_setup, _teardown) def test_send_to_spark_should_name_the_output_variable_the_same_as_input_name_when_custom_name_not_provided(): input_variable_name = "x_in" @@ -201,13 +236,18 @@ def test_send_to_spark_should_name_the_output_variable_the_same_as_input_name_wh max_rows = 25000 magic.shell.user_ns[input_variable_name] = input_variable_value magic.spark_controller.run_command.return_value = (True, output_value, "text/plain") - expected_message = u'Successfully passed \'{}\' as \'{}\' to Spark kernel'.format(input_variable_name, input_variable_name) + expected_message = "Successfully passed '{}' as '{}' to Spark kernel".format( + input_variable_name, input_variable_name + ) - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) 
+ magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) magic.ipython_display.write.assert_called_once_with(expected_message) assert not magic.ipython_display.send_error.called + @with_setup(_setup, _teardown) def test_send_to_spark_should_write_successfully_when_everything_is_correct(): input_variable_name = "x_in" @@ -217,42 +257,81 @@ def test_send_to_spark_should_write_successfully_when_everything_is_correct(): var_type = "str" magic.shell.user_ns[input_variable_name] = input_variable_value magic.spark_controller.run_command.return_value = (True, output_value, "text/plain") - expected_message = u'Successfully passed \'{}\' as \'{}\' to Spark kernel'.format(input_variable_name, output_variable_name) + expected_message = "Successfully passed '{}' as '{}' to Spark kernel".format( + input_variable_name, output_variable_name + ) - magic.do_send_to_spark("", input_variable_name, var_type, output_variable_name, max_rows, None) + magic.do_send_to_spark( + "", input_variable_name, var_type, output_variable_name, max_rows, None + ) magic.ipython_display.write.assert_called_once_with(expected_message) assert not magic.ipython_display.send_error.called + @with_setup(_setup, _teardown) def test_spark_execution_without_output_var(): output_var = None - - magic.spark_controller.run_command.return_value = (True,'out',MIMETYPE_TEXT_PLAIN) + + magic.spark_controller.run_command.return_value = (True, "out", MIMETYPE_TEXT_PLAIN) magic.execute_spark("", output_var, None, None, None, session, None) - magic.ipython_display.write.assert_called_once_with('out') + magic.ipython_display.write.assert_called_once_with("out") assert not magic.spark_controller._spark_store_command.called - magic.spark_controller.run_command.return_value = (False,'out',MIMETYPE_TEXT_PLAIN) - assert_raises(SparkStatementException, magic.execute_spark,"", output_var, None, None, None, session, True) + magic.spark_controller.run_command.return_value = ( + False, + "out", + MIMETYPE_TEXT_PLAIN, + ) + assert_raises( + SparkStatementException, + magic.execute_spark, + "", + output_var, + None, + None, + None, + session, + True, + ) assert not magic.spark_controller._spark_store_command.called + @with_setup(_setup, _teardown) def test_spark_execution_with_output_var(): mockSparkCommand = MagicMock() magic._spark_store_command = MagicMock(return_value=mockSparkCommand) output_var = "var_name" - df = 'df' + df = "df" - magic.spark_controller.run_command.side_effect = [(True,'out',MIMETYPE_TEXT_PLAIN), df] + magic.spark_controller.run_command.side_effect = [ + (True, "out", MIMETYPE_TEXT_PLAIN), + df, + ] magic.execute_spark("", output_var, None, None, None, session, True) - magic.ipython_display.write.assert_called_once_with('out') - magic._spark_store_command.assert_called_once_with(output_var, None, None, None, True) + magic.ipython_display.write.assert_called_once_with("out") + magic._spark_store_command.assert_called_once_with( + output_var, None, None, None, True + ) assert shell.user_ns[output_var] == df magic.spark_controller.run_command.side_effect = None - magic.spark_controller.run_command.return_value = (False,'out',MIMETYPE_TEXT_PLAIN) - assert_raises(SparkStatementException, magic.execute_spark,"", output_var, None, None, None, session, True) + magic.spark_controller.run_command.return_value = ( + False, + "out", + MIMETYPE_TEXT_PLAIN, + ) + assert_raises( + SparkStatementException, + magic.execute_spark, + "", + output_var, + None, + None, + None, + session, + True, + ) 
@with_setup(_setup, _teardown) @@ -261,34 +340,75 @@ def test_spark_exception_with_output_var(): magic._spark_store_command = MagicMock(return_value=mockSparkCommand) exception = BadUserDataException("Ka-boom!") output_var = "var_name" - df = 'df' - - magic.spark_controller.run_command.side_effect = [(True,'out',MIMETYPE_TEXT_PLAIN), exception] - assert_raises(BadUserDataException, magic.execute_spark,"", output_var, None, None, None, session, True) - magic.ipython_display.write.assert_called_once_with('out') - magic._spark_store_command.assert_called_once_with(output_var, None, None, None, True) + df = "df" + + magic.spark_controller.run_command.side_effect = [ + (True, "out", MIMETYPE_TEXT_PLAIN), + exception, + ] + assert_raises( + BadUserDataException, + magic.execute_spark, + "", + output_var, + None, + None, + None, + session, + True, + ) + magic.ipython_display.write.assert_called_once_with("out") + magic._spark_store_command.assert_called_once_with( + output_var, None, None, None, True + ) assert shell.user_ns == {} + @with_setup(_setup, _teardown) def test_spark_statement_exception(): mockSparkCommand = MagicMock() magic._spark_store_command = MagicMock(return_value=mockSparkCommand) exception = BadUserDataException("Ka-boom!") - magic.spark_controller.run_command.side_effect = [(False, 'out', "text/plain"), exception] - assert_raises(SparkStatementException, magic.execute_spark,"", None, None, None, None, session, True) + magic.spark_controller.run_command.side_effect = [ + (False, "out", "text/plain"), + exception, + ] + assert_raises( + SparkStatementException, + magic.execute_spark, + "", + None, + None, + None, + None, + session, + True, + ) magic.spark_controller.cleanup.assert_not_called() + @with_setup(_setup, _teardown) def test_spark_statement_exception_shutdowns_livy_session(): - conf.override_all({ - "shutdown_session_on_spark_statement_errors": True - }) + conf.override_all({"shutdown_session_on_spark_statement_errors": True}) mockSparkCommand = MagicMock() magic._spark_store_command = MagicMock(return_value=mockSparkCommand) exception = BadUserDataException("Ka-boom!") - magic.spark_controller.run_command.side_effect = [(False, 'out', "text/plain"), exception] - assert_raises(SparkStatementException, magic.execute_spark,"", None, None, None, None, session, True) + magic.spark_controller.run_command.side_effect = [ + (False, "out", "text/plain"), + exception, + ] + assert_raises( + SparkStatementException, + magic.execute_spark, + "", + None, + None, + None, + None, + session, + True, + ) magic.spark_controller.cleanup.assert_called_once() diff --git a/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py b/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py index c011c6aae..d73bb7b27 100644 --- a/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py +++ b/sparkmagic/sparkmagic/tests/test_sparkstorecommand.py @@ -12,17 +12,20 @@ backup_conf_defaults = None + def _setup(): global backup_conf_defaults backup_conf_defaults = { - 'samplemethod' : conf.default_samplemethod(), - 'maxrows': conf.default_maxrows(), - 'samplefraction': conf.default_samplefraction() + "samplemethod": conf.default_samplemethod(), + "maxrows": conf.default_maxrows(), + "samplefraction": conf.default_samplefraction(), } + def _teardown(): conf.override_all(backup_conf_defaults) + @with_setup(_setup, _teardown) def test_to_command_pyspark(): variable_name = "var_name" @@ -54,7 +57,9 @@ def test_to_command_r(): def test_to_command_invalid(): variable_name = "var_name" sparkcommand = 
SparkStoreCommand(variable_name) - assert_raises(BadUserDataException, sparkcommand.to_command, "invalid", variable_name) + assert_raises( + BadUserDataException, sparkcommand.to_command, "invalid", variable_name + ) @with_setup(_setup, _teardown) @@ -63,7 +68,9 @@ def test_sparkstorecommand_initializes(): samplemethod = "take" maxrows = 120 samplefraction = 0.6 - sparkcommand = SparkStoreCommand(variable_name, samplemethod, maxrows, samplefraction) + sparkcommand = SparkStoreCommand( + variable_name, samplemethod, maxrows, samplefraction + ) assert_equals(sparkcommand.samplemethod, samplemethod) assert_equals(sparkcommand.maxrows, maxrows) assert_equals(sparkcommand.samplefraction, samplefraction) @@ -79,137 +86,242 @@ def test_sparkstorecommand_loads_defaults(): conf.override_all(defaults) variable_name = "var_name" sparkcommand = SparkStoreCommand(variable_name) - assert_equals(sparkcommand.samplemethod, defaults[conf.default_samplemethod.__name__]) + assert_equals( + sparkcommand.samplemethod, defaults[conf.default_samplemethod.__name__] + ) assert_equals(sparkcommand.maxrows, defaults[conf.default_maxrows.__name__]) - assert_equals(sparkcommand.samplefraction, defaults[conf.default_samplefraction.__name__]) + assert_equals( + sparkcommand.samplefraction, defaults[conf.default_samplefraction.__name__] + ) @with_setup(_setup, _teardown) def test_pyspark_livy_sampling_options(): variable_name = "var_name" - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=120) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME))) + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=120) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})".format( + 
LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1 + ) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): " + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234 + ) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): " + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234 + ) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): " + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(100): print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + @with_setup(_setup, _teardown) def test_scala_livy_sampling_options(): variable_name = "abc" - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=100) - assert_equals(sparkcommand._scala_command(variable_name), - Command('{}.toJSON.take(100).foreach(println)'.format(variable_name))) + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=100) + assert_equals( + sparkcommand._scala_command(variable_name), + Command("{}.toJSON.take(100).foreach(println)".format(variable_name)), + ) + + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1) + assert_equals( + sparkcommand._scala_command(variable_name), + 
Command("{}.toJSON.collect.foreach(println)".format(variable_name)), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1 + ) + assert_equals( + sparkcommand._scala_command(variable_name), + Command( + "{}.toJSON.sample(false, 0.25).collect.foreach(println)".format( + variable_name + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234 + ) + assert_equals( + sparkcommand._scala_command(variable_name), + Command( + "{}.toJSON.sample(false, 0.33).take(3234).foreach(println)".format( + variable_name + ) + ), + ) - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1) - assert_equals(sparkcommand._scala_command(variable_name), - Command('{}.toJSON.collect.foreach(println)'.format(variable_name))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sparkcommand._scala_command(variable_name), - Command('{}.toJSON.sample(false, 0.25).collect.foreach(println)'.format(variable_name))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sparkcommand._scala_command(variable_name), - Command('{}.toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(variable_name))) - sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100) - assert_equals(sparkcommand._scala_command(variable_name), - Command('{}.toJSON.take(100).foreach(println)'.format(variable_name))) + assert_equals( + sparkcommand._scala_command(variable_name), + Command("{}.toJSON.take(100).foreach(println)".format(variable_name)), + ) + @with_setup(_setup, _teardown) def test_r_livy_sampling_options(): variable_name = "abc" - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=100) - - assert_equals(sparkcommand._r_command(variable_name), - Command('for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=-1) - assert_equals(sparkcommand._r_command(variable_name), - Command('for ({} in (jsonlite::toJSON(collect({})))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sparkcommand._r_command(variable_name), - Command('for ({} in (jsonlite::toJSON(collect(sample({}, FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME))) - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sparkcommand._r_command(variable_name), - Command('for ({} in (jsonlite::toJSON(take(sample({}, FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME))) + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=100) + + assert_equals( + sparkcommand._r_command(variable_name), + Command( + "for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=-1) + assert_equals( + sparkcommand._r_command(variable_name), + Command( + "for ({} in 
(jsonlite::toJSON(collect({})))) {{cat({})}}".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.25, maxrows=-1 + ) + assert_equals( + sparkcommand._r_command(variable_name), + Command( + "for ({} in (jsonlite::toJSON(collect(sample({}, FALSE, 0.25))))) {{cat({})}}".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sparkcommand = SparkStoreCommand( + variable_name, samplemethod="sample", samplefraction=0.33, maxrows=3234 + ) + assert_equals( + sparkcommand._r_command(variable_name), + Command( + "for ({} in (jsonlite::toJSON(take(sample({}, FALSE, 0.33),3234)))) {{cat({})}}".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) sparkcommand = SparkStoreCommand(variable_name, samplemethod=None, maxrows=100) - assert_equals(sparkcommand._r_command(variable_name), - Command('for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) + assert_equals( + sparkcommand._r_command(variable_name), + Command( + "for ({} in (jsonlite::toJSON(take({},100)))) {{cat({})}}".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + @with_setup(_setup, _teardown) def test_execute_code(): spark_events = MagicMock() variable_name = "abc" - sparkcommand = SparkStoreCommand(variable_name, "take", 100, 0.2, spark_events=spark_events) + sparkcommand = SparkStoreCommand( + variable_name, "take", 100, 0.2, spark_events=spark_events + ) sparkcommand.to_command = MagicMock(return_value=MagicMock()) result = """{"z":100, "nullv":null, "y":50} {"z":25, "nullv":null, "y":10}""" - sparkcommand.to_command.return_value.execute = MagicMock(return_value=(True, result, MIMETYPE_TEXT_PLAIN)) + sparkcommand.to_command.return_value.execute = MagicMock( + return_value=(True, result, MIMETYPE_TEXT_PLAIN) + ) session = MagicMock() session.kind = "pyspark" result = sparkcommand.execute(session) - + sparkcommand.to_command.assert_called_once_with(session.kind, variable_name) sparkcommand.to_command.return_value.execute.assert_called_once_with(session) @with_setup(_setup, _teardown) def test_unicode(): - variable_name = u"collect 'รจ'" - - sparkcommand = SparkStoreCommand(variable_name, samplemethod='take', maxrows=120) - assert_equals(sparkcommand._pyspark_command(variable_name), - Command(u'import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, variable_name, - LONG_RANDOM_VARIABLE_NAME))) - assert_equals(sparkcommand._scala_command(variable_name), - Command(u'{}.toJSON.take(120).foreach(println)'.format(variable_name))) - + variable_name = "collect 'รจ'" + + sparkcommand = SparkStoreCommand(variable_name, samplemethod="take", maxrows=120) + assert_equals( + sparkcommand._pyspark_command(variable_name), + Command( + "import sys\nfor {} in {}.toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})".format( + LONG_RANDOM_VARIABLE_NAME, variable_name, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + assert_equals( + sparkcommand._scala_command(variable_name), + Command("{}.toJSON.take(120).foreach(println)".format(variable_name)), + ) diff --git a/sparkmagic/sparkmagic/tests/test_sqlquery.py b/sparkmagic/sparkmagic/tests/test_sqlquery.py index cae49af8d..f685f0042 100644 --- a/sparkmagic/sparkmagic/tests/test_sqlquery.py 
+++ b/sparkmagic/sparkmagic/tests/test_sqlquery.py @@ -13,7 +13,7 @@ def _setup(): pass - + def _teardown(): pass @@ -54,7 +54,9 @@ def test_sqlquery_loads_defaults(): assert_equals(sqlquery.query, query) assert_equals(sqlquery.samplemethod, defaults[conf.default_samplemethod.__name__]) assert_equals(sqlquery.maxrows, defaults[conf.default_maxrows.__name__]) - assert_equals(sqlquery.samplefraction, defaults[conf.default_samplefraction.__name__]) + assert_equals( + sqlquery.samplefraction, defaults[conf.default_samplefraction.__name__] + ) @with_setup(_setup, _teardown) @@ -69,101 +71,181 @@ def test_sqlquery_rejects_bad_data(): def test_pyspark_livy_sql_options(): query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=120) - assert_equals(sqlquery._pyspark_command("sqlContext"), - Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._pyspark_command("sqlContext"), - Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._pyspark_command("sqlContext"), - Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._pyspark_command("sqlContext"), - Command(u'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME))) + sqlquery = SQLQuery(query, samplemethod="take", maxrows=120) + assert_equals( + sqlquery._pyspark_command("sqlContext"), + Command( + 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._pyspark_command("sqlContext"), + Command( + 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._pyspark_command("sqlContext"), + Command( + 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): ' + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery 
= SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._pyspark_command("sqlContext"), + Command( + 'import sys\nfor {} in sqlContext.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + @with_setup(_setup, _teardown) def test_scala_livy_sql_options(): query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=100) - assert_equals(sqlquery._scala_command("sqlContext"), - Command('sqlContext.sql("""{}""").toJSON.take(100).foreach(println)'.format(query))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._scala_command("sqlContext"), - Command('sqlContext.sql("""{}""").toJSON.collect.foreach(println)'.format(query))) + sqlquery = SQLQuery(query, samplemethod="take", maxrows=100) + assert_equals( + sqlquery._scala_command("sqlContext"), + Command( + 'sqlContext.sql("""{}""").toJSON.take(100).foreach(println)'.format(query) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._scala_command("sqlContext"), + Command( + 'sqlContext.sql("""{}""").toJSON.collect.foreach(println)'.format(query) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._scala_command("sqlContext"), + Command( + 'sqlContext.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format( + query + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._scala_command("sqlContext"), + Command( + 'sqlContext.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format( + query + ) + ), + ) - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._scala_command("sqlContext"), - Command('sqlContext.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(query))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._scala_command("sqlContext"), - Command('sqlContext.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(query))) @with_setup(_setup, _teardown) def test_r_livy_sql_options_spark(): - query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=100) - sqlContext = "sqlContext" - - assert_equals(sqlquery._r_command(sqlContext), - Command('for ({} in (jsonlite:::toJSON(take(sql({}, "{}"),100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._r_command(sqlContext), - Command('for ({} in (jsonlite:::toJSON(collect(sql({}, "{}"))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._r_command(sqlContext), - Command('for ({} in 
(jsonlite:::toJSON(collect(sample(sql({}, "{}"), FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._r_command(sqlContext), - Command('for ({} in (jsonlite:::toJSON(take(sample(sql({}, "{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME))) + query = "abc" + sqlquery = SQLQuery(query, samplemethod="take", maxrows=100) + sqlContext = "sqlContext" + + assert_equals( + sqlquery._r_command(sqlContext), + Command( + 'for ({} in (jsonlite:::toJSON(take(sql({}, "{}"),100)))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._r_command(sqlContext), + Command( + 'for ({} in (jsonlite:::toJSON(collect(sql({}, "{}"))))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._r_command(sqlContext), + Command( + 'for ({} in (jsonlite:::toJSON(collect(sample(sql({}, "{}"), FALSE, 0.25))))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._r_command(sqlContext), + Command( + 'for ({} in (jsonlite:::toJSON(take(sample(sql({}, "{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, sqlContext, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) @with_setup(_setup, _teardown) def test_execute_sql(): spark_events = MagicMock() - sqlquery = SQLQuery("HERE IS THE QUERY", "take", 100, 0.2, spark_events=spark_events) + sqlquery = SQLQuery( + "HERE IS THE QUERY", "take", 100, 0.2, spark_events=spark_events + ) sqlquery.to_command = MagicMock(return_value=MagicMock()) result = """{"z":100, "nullv":null, "y":50} {"z":25, "nullv":null, "y":10}""" - sqlquery.to_command.return_value.execute = MagicMock(return_value=(True, result, MIMETYPE_TEXT_PLAIN)) - result_data = pd.DataFrame([{'z': 100, "nullv": None, 'y': 50}, {'z':25, "nullv":None, 'y':10}], columns=['z', "nullv", 'y']) + sqlquery.to_command.return_value.execute = MagicMock( + return_value=(True, result, MIMETYPE_TEXT_PLAIN) + ) + result_data = pd.DataFrame( + [{"z": 100, "nullv": None, "y": 50}, {"z": 25, "nullv": None, "y": 10}], + columns=["z", "nullv", "y"], + ) session = MagicMock() session.kind = "pyspark" result = sqlquery.execute(session) assert_frame_equal(result, result_data) sqlquery.to_command.return_value.execute.assert_called_once_with(session) - spark_events.emit_sql_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, sqlquery.guid, - 'take', 100, 0.2) - spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, sqlquery.guid, - sqlquery.to_command.return_value.guid, - True, '','') + spark_events.emit_sql_execution_start_event.assert_called_once_with( + session.guid, session.kind, session.id, sqlquery.guid, "take", 100, 0.2 + ) + spark_events.emit_sql_execution_end_event.assert_called_once_with( + session.guid, + session.kind, + session.id, + sqlquery.guid, + sqlquery.to_command.return_value.guid, + True, + "", + "", + ) 
@with_setup(_setup, _teardown) @@ -177,19 +259,34 @@ def test_execute_sql_no_results(): result1 = "" result_data = pd.DataFrame([]) session = MagicMock() - sqlquery.to_command.return_value.execute.return_value = (True, result1, MIMETYPE_TEXT_PLAIN) + sqlquery.to_command.return_value.execute.return_value = ( + True, + result1, + MIMETYPE_TEXT_PLAIN, + ) session.kind = "spark" result = sqlquery.execute(session) assert_frame_equal(result, result_data) sqlquery.to_command.return_value.execute.assert_called_once_with(session) - spark_events.emit_sql_execution_start_event.assert_called_once_with(session.guid, session.kind, - session.id, sqlquery.guid, - sqlquery.samplemethod, sqlquery.maxrows, - sqlquery.samplefraction) - spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, sqlquery.guid, - sqlquery.to_command.return_value.guid, - True, "", "") + spark_events.emit_sql_execution_start_event.assert_called_once_with( + session.guid, + session.kind, + session.id, + sqlquery.guid, + sqlquery.samplemethod, + sqlquery.maxrows, + sqlquery.samplefraction, + ) + spark_events.emit_sql_execution_end_event.assert_called_once_with( + session.guid, + session.kind, + session.id, + sqlquery.guid, + sqlquery.to_command.return_value.guid, + True, + "", + "", + ) @with_setup(_setup, _teardown) @@ -197,7 +294,7 @@ def test_execute_sql_failure_emits_event(): spark_events = MagicMock() sqlquery = SQLQuery("HERE IS THE QUERY", "take", 100, 0.2, spark_events) sqlquery.to_command = MagicMock() - sqlquery.to_command.return_value.execute = MagicMock(side_effect=ValueError('yo')) + sqlquery.to_command.return_value.execute = MagicMock(side_effect=ValueError("yo")) session = MagicMock() session.kind = "pyspark" try: @@ -205,93 +302,170 @@ def test_execute_sql_failure_emits_event(): assert False except ValueError: sqlquery.to_command.return_value.execute.assert_called_once_with(session) - spark_events.emit_sql_execution_end_event.assert_called_once_with(session.guid, session.kind, - session.id, sqlquery.guid, - sqlquery.to_command.return_value.guid, - False, 'ValueError', 'yo') + spark_events.emit_sql_execution_end_event.assert_called_once_with( + session.guid, + session.kind, + session.id, + sqlquery.guid, + sqlquery.to_command.return_value.guid, + False, + "ValueError", + "yo", + ) @with_setup(_setup, _teardown) def test_unicode_sql(): - query = u"SELECT 'รจ'" + query = "SELECT 'รจ'" longvar = LONG_RANDOM_VARIABLE_NAME - sqlquery = SQLQuery(query, samplemethod='take', maxrows=120) - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\ - .format(longvar, query, - longvar))) - assert_equals(sqlquery._scala_command("spark"), - Command(u'spark.sql("""{}""").toJSON.take(120).foreach(println)'.format(query))) - assert_equals(sqlquery._r_command("spark"), - Command(u'for ({} in (jsonlite:::toJSON(take(sql("{}"),120)))) {{cat({})}}'.format(longvar, query, longvar))) + sqlquery = SQLQuery(query, samplemethod="take", maxrows=120) + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format( + longvar, query, longvar + ) + ), + ) + assert_equals( + sqlquery._scala_command("spark"), + Command('spark.sql("""{}""").toJSON.take(120).foreach(println)'.format(query)), + ) + assert_equals( + sqlquery._r_command("spark"), + Command( + 'for 
({} in (jsonlite:::toJSON(take(sql("{}"),120)))) {{cat({})}}'.format( + longvar, query, longvar + ) + ), + ) + @with_setup(_setup, _teardown) def test_pyspark_livy_sql_options_spark2(): - query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=120) - - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._pyspark_command("spark"), - Command(u'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' - u'print({})'\ - .format(LONG_RANDOM_VARIABLE_NAME, query, - LONG_RANDOM_VARIABLE_NAME))) + query = "abc" + sqlquery = SQLQuery(query, samplemethod="take", maxrows=120) + + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).take(120): print({})'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).collect(): print({})'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.25).collect(): ' + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._pyspark_command("spark"), + Command( + 'import sys\nfor {} in spark.sql(u"""{} """).toJSON(use_unicode=(sys.version_info.major > 2)).sample(False, 0.33).take(3234): ' + "print({})".format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + @with_setup(_setup, _teardown) def test_scala_livy_sql_options_spark2(): - query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=100) - - assert_equals(sqlquery._scala_command("spark"), - Command('spark.sql("""{}""").toJSON.take(100).foreach(println)'.format(query))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._scala_command("spark"), - Command('spark.sql("""{}""").toJSON.collect.foreach(println)'.format(query))) - - sqlquery = SQLQuery(query, 
samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._scala_command("spark"), - Command('spark.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format(query))) + query = "abc" + sqlquery = SQLQuery(query, samplemethod="take", maxrows=100) + + assert_equals( + sqlquery._scala_command("spark"), + Command('spark.sql("""{}""").toJSON.take(100).foreach(println)'.format(query)), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._scala_command("spark"), + Command('spark.sql("""{}""").toJSON.collect.foreach(println)'.format(query)), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._scala_command("spark"), + Command( + 'spark.sql("""{}""").toJSON.sample(false, 0.25).collect.foreach(println)'.format( + query + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, maxrows=3234) + assert_equals( + sqlquery._scala_command("spark"), + Command( + 'spark.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format( + query + ) + ), + ) - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._scala_command("spark"), - Command('spark.sql("""{}""").toJSON.sample(false, 0.33).take(3234).foreach(println)'.format(query))) @with_setup(_setup, _teardown) def test_r_livy_sql_options_spark2(): - query = "abc" - sqlquery = SQLQuery(query, samplemethod='take', maxrows=100) - - assert_equals(sqlquery._r_command("spark"), - Command('for ({} in (jsonlite:::toJSON(take(sql("{}"),100)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='take', maxrows=-1) - assert_equals(sqlquery._r_command("spark"), - Command('for ({} in (jsonlite:::toJSON(collect(sql("{}"))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.25, maxrows=-1) - assert_equals(sqlquery._r_command("spark"), - Command('for ({} in (jsonlite:::toJSON(collect(sample(sql("{}"), FALSE, 0.25))))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME))) - - sqlquery = SQLQuery(query, samplemethod='sample', samplefraction=0.33, maxrows=3234) - assert_equals(sqlquery._r_command("spark"), - Command('for ({} in (jsonlite:::toJSON(take(sample(sql("{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format(LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME))) + query = "abc" + sqlquery = SQLQuery(query, samplemethod="take", maxrows=100) + + assert_equals( + sqlquery._r_command("spark"), + Command( + 'for ({} in (jsonlite:::toJSON(take(sql("{}"),100)))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="take", maxrows=-1) + assert_equals( + sqlquery._r_command("spark"), + Command( + 'for ({} in (jsonlite:::toJSON(collect(sql("{}"))))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.25, maxrows=-1) + assert_equals( + sqlquery._r_command("spark"), + Command( + 'for ({} in (jsonlite:::toJSON(collect(sample(sql("{}"), FALSE, 0.25))))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) + + sqlquery = SQLQuery(query, samplemethod="sample", samplefraction=0.33, 
maxrows=3234) + assert_equals( + sqlquery._r_command("spark"), + Command( + 'for ({} in (jsonlite:::toJSON(take(sample(sql("{}"), FALSE, 0.33),3234)))) {{cat({})}}'.format( + LONG_RANDOM_VARIABLE_NAME, query, LONG_RANDOM_VARIABLE_NAME + ) + ), + ) diff --git a/sparkmagic/sparkmagic/tests/test_usercodeparser.py b/sparkmagic/sparkmagic/tests/test_usercodeparser.py index b1e54a67b..10cb776bb 100644 --- a/sparkmagic/sparkmagic/tests/test_usercodeparser.py +++ b/sparkmagic/sparkmagic/tests/test_usercodeparser.py @@ -8,67 +8,72 @@ def test_empty_string(): parser = UserCodeParser() - assert_equals(u"", parser.get_code_to_run(u"")) + assert_equals("", parser.get_code_to_run("")) def test_spark_code(): parser = UserCodeParser() - cell = u"my code\nand more" + cell = "my code\nand more" - assert_equals(u"%%spark\nmy code\nand more", parser.get_code_to_run(cell)) + assert_equals("%%spark\nmy code\nand more", parser.get_code_to_run(cell)) def test_local_single(): parser = UserCodeParser() - cell = u"""%local + cell = """%local hi hi hi""" - assert_equals(u"hi\nhi\nhi", parser.get_code_to_run(cell)) + assert_equals("hi\nhi\nhi", parser.get_code_to_run(cell)) def test_local_double(): parser = UserCodeParser() - cell = u"""%%local + cell = """%%local hi hi hi""" - assert_equals(u"hi\nhi\nhi", parser.get_code_to_run(cell)) + assert_equals("hi\nhi\nhi", parser.get_code_to_run(cell)) def test_our_line_magics(): parser = UserCodeParser() magic_name = KernelMagics.info.__name__ - cell = u"%{}".format(magic_name) + cell = "%{}".format(magic_name) - assert_equals(u"%%{}\n ".format(magic_name), parser.get_code_to_run(cell)) + assert_equals("%%{}\n ".format(magic_name), parser.get_code_to_run(cell)) def test_our_line_magics_with_content(): parser = UserCodeParser() magic_name = KernelMagics.info.__name__ - cell = u"""%{} + cell = """%{} my content -more content""".format(magic_name) +more content""".format( + magic_name + ) - assert_equals(u"%%{}\nmy content\nmore content\n ".format(magic_name), parser.get_code_to_run(cell)) + assert_equals( + "%%{}\nmy content\nmore content\n ".format(magic_name), + parser.get_code_to_run(cell), + ) def test_other_cell_magic(): parser = UserCodeParser() - cell = u"""%%magic + cell = """%%magic hi hi hi""" - assert_equals(u"{}".format(cell), parser.get_code_to_run(cell)) + assert_equals("{}".format(cell), parser.get_code_to_run(cell)) def test_other_line_magic(): parser = UserCodeParser() - cell = u"""%magic + cell = """%magic hi hi hi""" @@ -78,36 +83,46 @@ def test_other_line_magic(): def test_scala_code(): parser = UserCodeParser() - cell = u"""/* Place the cursor in the cell and press SHIFT + ENTER to run */ + cell = """/* Place the cursor in the cell and press SHIFT + ENTER to run */ val fruits = sc.textFile("wasb:///example/data/fruits.txt") val yellowThings = sc.textFile("wasb:///example/data/yellowthings.txt")""" - assert_equals(u"%%spark\n{}".format(cell), parser.get_code_to_run(cell)) + assert_equals("%%spark\n{}".format(cell), parser.get_code_to_run(cell)) def test_unicode(): parser = UserCodeParser() - cell = u"print 'รจ๐Ÿ™๐Ÿ™๐Ÿ™๐Ÿ™'" + cell = "print 'รจ๐Ÿ™๐Ÿ™๐Ÿ™๐Ÿ™'" - assert_equals(u"%%spark\n{}".format(cell), parser.get_code_to_run(cell)) + assert_equals("%%spark\n{}".format(cell), parser.get_code_to_run(cell)) def test_unicode_in_magics(): parser = UserCodeParser() magic_name = KernelMagics.info.__name__ - cell = u"""%{} + cell = """%{} my content รจ๐Ÿ™ -more content""".format(magic_name) +more content""".format( + magic_name + ) - 
assert_equals(u"%%{}\nmy content รจ๐Ÿ™\nmore content\n ".format(magic_name), parser.get_code_to_run(cell)) + assert_equals( + "%%{}\nmy content รจ๐Ÿ™\nmore content\n ".format(magic_name), + parser.get_code_to_run(cell), + ) def test_unicode_in_double_magics(): parser = UserCodeParser() magic_name = KernelMagics.info.__name__ - cell = u"""%%{} + cell = """%%{} my content รจ๐Ÿ™ -more content""".format(magic_name) - - assert_equals(u"%%{}\nmy content รจ๐Ÿ™\nmore content\n ".format(magic_name), parser.get_code_to_run(cell)) +more content""".format( + magic_name + ) + + assert_equals( + "%%{}\nmy content รจ๐Ÿ™\nmore content\n ".format(magic_name), + parser.get_code_to_run(cell), + ) diff --git a/sparkmagic/sparkmagic/tests/test_utils.py b/sparkmagic/sparkmagic/tests/test_utils.py index 72779aedf..71fe67883 100644 --- a/sparkmagic/sparkmagic/tests/test_utils.py +++ b/sparkmagic/sparkmagic/tests/test_utils.py @@ -8,22 +8,32 @@ from sparkmagic.livyclientlib.exceptions import BadUserDataException from sparkmagic.utils.utils import parse_argstring_or_throw, records_to_dataframe from sparkmagic.utils.constants import SESSION_KIND_PYSPARK -from sparkmagic.utils.dataframe_parser import DataframeHtmlParser, cell_contains_dataframe, CellComponentType, cell_components_iter, CellOutputHtmlParser +from sparkmagic.utils.dataframe_parser import ( + DataframeHtmlParser, + cell_contains_dataframe, + CellComponentType, + cell_components_iter, + CellOutputHtmlParser, +) import unittest def test_parse_argstring_or_throw(): - parse_argstring = MagicMock(side_effect=UsageError('OOGABOOGABOOGA')) + parse_argstring = MagicMock(side_effect=UsageError("OOGABOOGABOOGA")) try: - parse_argstring_or_throw(MagicMock(), MagicMock(), parse_argstring=parse_argstring) + parse_argstring_or_throw( + MagicMock(), MagicMock(), parse_argstring=parse_argstring + ) assert False except BadUserDataException as e: assert_equals(str(e), str(parse_argstring.side_effect)) - parse_argstring = MagicMock(side_effect=ValueError('AN UNKNOWN ERROR HAPPENED')) + parse_argstring = MagicMock(side_effect=ValueError("AN UNKNOWN ERROR HAPPENED")) try: - parse_argstring_or_throw(MagicMock(), MagicMock(), parse_argstring=parse_argstring) + parse_argstring_or_throw( + MagicMock(), MagicMock(), parse_argstring=parse_argstring + ) assert False except ValueError as e: assert_is(e, parse_argstring.side_effect) @@ -32,36 +42,51 @@ def test_parse_argstring_or_throw(): def test_records_to_dataframe_missing_value_first(): result = """{"z":100, "y":50} {"z":25, "nullv":1.0, "y":10}""" - + df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True) - expected = pd.DataFrame([{'z': 100, "nullv": None, 'y': 50}, {'z':25, "nullv":1, 'y':10}], columns=['z', "nullv", 'y']) + expected = pd.DataFrame( + [{"z": 100, "nullv": None, "y": 50}, {"z": 25, "nullv": 1, "y": 10}], + columns=["z", "nullv", "y"], + ) assert_frame_equal(expected, df) def test_records_to_dataframe_coercing(): result = """{"z":"100", "y":"2016-01-01"} {"z":"25", "y":"2016-01-01"}""" - + df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True) - expected = pd.DataFrame([{'z': 100, 'y': np.datetime64("2016-01-01")}, {'z':25, 'y':np.datetime64("2016-01-01")}], columns=['z', 'y']) + expected = pd.DataFrame( + [ + {"z": 100, "y": np.datetime64("2016-01-01")}, + {"z": 25, "y": np.datetime64("2016-01-01")}, + ], + columns=["z", "y"], + ) assert_frame_equal(expected, df) def test_records_to_dataframe_no_coercing(): result = """{"z":"100", "y":"2016-01-01"} {"z":"25", "y":"2016-01-01"}""" - + df = 
records_to_dataframe(result, SESSION_KIND_PYSPARK, False) - expected = pd.DataFrame([{'z': "100", 'y': "2016-01-01"}, {'z':"25", 'y':"2016-01-01"}], columns=['z', 'y']) + expected = pd.DataFrame( + [{"z": "100", "y": "2016-01-01"}, {"z": "25", "y": "2016-01-01"}], + columns=["z", "y"], + ) assert_frame_equal(expected, df) def test_records_to_dataframe_missing_value_later(): result = """{"z":25, "nullv":1.0, "y":10} {"z":100, "y":50}""" - + df = records_to_dataframe(result, SESSION_KIND_PYSPARK, True) - expected = pd.DataFrame([{'z':25, "nullv":1, 'y':10}, {'z': 100, "nullv": None, 'y': 50}], columns=['z', "nullv", 'y']) + expected = pd.DataFrame( + [{"z": 25, "nullv": 1, "y": 10}, {"z": 100, "nullv": None, "y": 50}], + columns=["z", "nullv", "y"], + ) assert_frame_equal(expected, df) @@ -80,9 +105,9 @@ def test_dataframe_component(self): """ dc = DataframeHtmlParser(cell) rows = dc.row_iter() - self.assertDictEqual(next(rows), {'id': '1', 'animal': 'bat'}) - self.assertDictEqual(next(rows), {'id': '2', 'animal': 'mouse'}) - self.assertDictEqual(next(rows), {'id': '3', 'animal': 'horse'}) + self.assertDictEqual(next(rows), {"id": "1", "animal": "bat"}) + self.assertDictEqual(next(rows), {"id": "2", "animal": "mouse"}) + self.assertDictEqual(next(rows), {"id": "3", "animal": "horse"}) with self.assertRaises(StopIteration): next(rows) @@ -93,7 +118,7 @@ def test_dataframe_component(self): +---+------+ """ dc = DataframeHtmlParser(cell) - rows = dc.row_iter() + rows = dc.row_iter() with self.assertRaises(StopIteration): next(rows) @@ -109,13 +134,13 @@ def test_dataframe_component(self): Only showing the last 20 rows """ dc = DataframeHtmlParser(cell) - rows = dc.row_iter() - self.assertDictEqual(next(rows), {'id': '1', 'animal': 'bat'}) + rows = dc.row_iter() + self.assertDictEqual(next(rows), {"id": "1", "animal": "bat"}) with self.assertRaises(ValueError): next(rows) def test_dataframe_parsing(self): - + cell = """+---+------+ | id|animal| +---+------+ @@ -139,8 +164,9 @@ def test_dataframe_parsing(self): Only showing the last 20 rows """ - self.assertTrue(cell_contains_dataframe(cell), "Matches with leading whitespaces") - + self.assertTrue( + cell_contains_dataframe(cell), "Matches with leading whitespaces" + ) cell = """ +---+------+ @@ -160,7 +186,9 @@ def test_dataframe_parsing(self): +---+------+ """ - self.assertTrue(cell_contains_dataframe(cell), "Cell contains multiple dataframes") + self.assertTrue( + cell_contains_dataframe(cell), "Cell contains multiple dataframes" + ) cell = """ +---+------+ @@ -190,7 +218,7 @@ def test_dataframe_parsing(self): Only showing the last 20 rows """ self.assertFalse(cell_contains_dataframe(cell), "Footer contains a /") - + def test_cell_components_iter(self): cell = """+---+------+ | id|animal| @@ -203,7 +231,7 @@ def test_cell_components_iter(self): Only showing the last 20 rows """ df_spans = cell_components_iter(cell) - + self.assertEqual(next(df_spans), (CellComponentType.DF, 0, 90)) self.assertEqual(next(df_spans), (CellComponentType.TEXT, 90, 143)) self.assertEqual(cell[90:143].strip(), "Only showing the last 20 rows") @@ -235,7 +263,7 @@ def test_cell_components_iter(self): Random stuff at the end """ df_spans = cell_components_iter(cell) - + self.assertEqual(next(df_spans), (CellComponentType.TEXT, 0, 45)) self.assertEqual(cell[0:45].strip(), "Random stuff at the start") self.assertEqual(next(df_spans), (CellComponentType.DF, 45, 247)) @@ -247,7 +275,7 @@ def test_cell_components_iter(self): self.assertEqual(next(df_spans), 
(CellComponentType.TEXT, 495, 553)) self.assertEqual(cell[495:553].strip(), "Random stuff at the end") - + with self.assertRaises(StopIteration): next(df_spans) @@ -257,7 +285,7 @@ def test_cell_components_iter(self): """ df_spans = cell_components_iter(cell) - + self.assertEqual(next(df_spans), (CellComponentType.TEXT, 0, len(cell))) with self.assertRaises(StopIteration): next(df_spans) @@ -272,7 +300,9 @@ def test_output_html_parser(self): self.assertEqual(CellOutputHtmlParser.to_html(cell), "") cell = "Some random text" - self.assertEqual(CellOutputHtmlParser.to_html(cell), "
<pre>Some random text</pre>")
+        self.assertEqual(
+            CellOutputHtmlParser.to_html(cell), "<pre>Some random text</pre>"
+        )
 
         cell = """
@@ -298,18 +328,23 @@ def test_output_html_parser(self):
         Random stuff at the end
         """
         self.maxDiff = 1200
-        self.assertEqual(CellOutputHtmlParser.to_html(cell),
-                         ("<pre>Random stuff at the start</pre> <table>"
-                          "<tr><th>id</th><th>animal</th></tr>"
-                          "<tr><td>1</td><td>bat</td></tr>"
-                          "<tr><td>2</td><td>mouse</td></tr>"
-                          "<tr><td>3</td><td>horse</td></tr>"
-                          "</table> <pre>Random stuff in the middle</pre> <table>"
-                          "<tr><th>id</th><th>animal</th></tr>"
-                          "<tr><td>1</td><td>cat</td></tr>"
-                          "<tr><td>2</td><td>couse</td></tr>"
-                          "<tr><td>3</td><td>corse</td></tr>"
-                          "</table> <pre>Random stuff at the end</pre>"))
-
-if __name__ == '__main__':
+        self.assertEqual(
+            CellOutputHtmlParser.to_html(cell),
+            (
+                "<pre>Random stuff at the start</pre> <table>"
+                "<tr><th>id</th><th>animal</th></tr>"
+                "<tr><td>1</td><td>bat</td></tr>"
+                "<tr><td>2</td><td>mouse</td></tr>"
+                "<tr><td>3</td><td>horse</td></tr>"
+                "</table> <pre>Random stuff in the middle</pre> <table>"
+                "<tr><th>id</th><th>animal</th></tr>"
+                "<tr><td>1</td><td>cat</td></tr>"
+                "<tr><td>2</td><td>couse</td></tr>"
+                "<tr><td>3</td><td>corse</td></tr>"
+                "</table> <pre>Random stuff at the end</pre>
" + ), + ) + + +if __name__ == "__main__": unittest.main() diff --git a/sparkmagic/sparkmagic/utils/configuration.py b/sparkmagic/sparkmagic/utils/configuration.py index 515492b43..68a9a95ec 100644 --- a/sparkmagic/sparkmagic/utils/configuration.py +++ b/sparkmagic/sparkmagic/utils/configuration.py @@ -2,18 +2,33 @@ import copy import sys import base64 -from hdijupyterutils.constants import EVENTS_HANDLER_CLASS_NAME, LOGGING_CONFIG_CLASS_NAME +from hdijupyterutils.constants import ( + EVENTS_HANDLER_CLASS_NAME, + LOGGING_CONFIG_CLASS_NAME, +) from hdijupyterutils.utils import join_paths from hdijupyterutils.configuration import override as _override from hdijupyterutils.configuration import override_all as _override_all from hdijupyterutils.configuration import with_override -from .constants import HOME_PATH, CONFIG_FILE, MAGICS_LOGGER_NAME, LIVY_KIND_PARAM, \ - LANG_SCALA, LANG_PYTHON, LANG_R, \ - SESSION_KIND_SPARKR, SESSION_KIND_SPARK, SESSION_KIND_PYSPARK, CONFIGURABLE_RETRY, \ - NO_AUTH, AUTH_BASIC +from .constants import ( + HOME_PATH, + CONFIG_FILE, + MAGICS_LOGGER_NAME, + LIVY_KIND_PARAM, + LANG_SCALA, + LANG_PYTHON, + LANG_R, + SESSION_KIND_SPARKR, + SESSION_KIND_SPARK, + SESSION_KIND_PYSPARK, + CONFIGURABLE_RETRY, + NO_AUTH, + AUTH_BASIC, +) from sparkmagic.livyclientlib.exceptions import BadUserConfigurationException -#import sparkmagic.utils.constants as constants + +# import sparkmagic.utils.constants as constants from requests_kerberos import REQUIRED @@ -22,7 +37,7 @@ d = {} path = join_paths(HOME_PATH, CONFIG_FILE) - + def override(config, value): _override(d, path, config, value) @@ -36,6 +51,7 @@ def override_all(obj): # Helpers + def get_livy_kind(language): if language == LANG_SCALA: return SESSION_KIND_SPARK @@ -44,26 +60,29 @@ def get_livy_kind(language): elif language == LANG_R: return SESSION_KIND_SPARKR else: - raise BadUserConfigurationException("Cannot get session kind for {}.".format(language)) + raise BadUserConfigurationException( + "Cannot get session kind for {}.".format(language) + ) def get_auth_value(username, password): - if username == '' and password == '': + if username == "" and password == "": return NO_AUTH return AUTH_BASIC @_with_override def authenticators(): - return { - u"Kerberos": u"sparkmagic.auth.kerberos.Kerberos", - u"None": u"sparkmagic.auth.customauth.Authenticator", - u"Basic_Access": u"sparkmagic.auth.basic.Basic" + return { + "Kerberos": "sparkmagic.auth.kerberos.Kerberos", + "None": "sparkmagic.auth.customauth.Authenticator", + "Basic_Access": "sparkmagic.auth.basic.Basic", } - - + + # Configs + def get_session_properties(language): properties = copy.deepcopy(session_configs()) properties[LIVY_KIND_PARAM] = get_livy_kind(language) @@ -77,9 +96,14 @@ def session_configs(): @_with_override def kernel_python_credentials(): - return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH} - - + return { + "username": "", + "base64_password": "", + "url": "http://localhost:8998", + "auth": NO_AUTH, + } + + def base64_kernel_python_credentials(): return _credentials_override(kernel_python_credentials) @@ -96,15 +120,26 @@ def base64_kernel_python3_credentials(): @_with_override def kernel_scala_credentials(): - return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH} + return { + "username": "", + "base64_password": "", + "url": "http://localhost:8998", + "auth": NO_AUTH, + } -def base64_kernel_scala_credentials(): +def base64_kernel_scala_credentials(): 
return _credentials_override(kernel_scala_credentials) + @_with_override def kernel_r_credentials(): - return {u'username': u'', u'base64_password': u'', u'url': u'http://localhost:8998', u'auth': NO_AUTH} + return { + "username": "", + "base64_password": "", + "url": "http://localhost:8998", + "auth": NO_AUTH, + } def base64_kernel_r_credentials(): @@ -114,27 +149,27 @@ def base64_kernel_r_credentials(): @_with_override def logging_config(): return { - u"version": 1, - u"formatters": { - u"magicsFormatter": { - u"format": u"%(asctime)s\t%(levelname)s\t%(message)s", - u"datefmt": u"" + "version": 1, + "formatters": { + "magicsFormatter": { + "format": "%(asctime)s\t%(levelname)s\t%(message)s", + "datefmt": "", } }, - u"handlers": { - u"magicsHandler": { - u"class": LOGGING_CONFIG_CLASS_NAME, - u"formatter": u"magicsFormatter", - u"home_path": HOME_PATH + "handlers": { + "magicsHandler": { + "class": LOGGING_CONFIG_CLASS_NAME, + "formatter": "magicsFormatter", + "home_path": HOME_PATH, } }, - u"loggers": { + "loggers": { MAGICS_LOGGER_NAME: { - u"handlers": [u"magicsHandler"], - u"level": u"DEBUG", - u"propagate": 0 + "handlers": ["magicsHandler"], + "level": "DEBUG", + "propagate": 0, } - } + }, } @@ -155,7 +190,7 @@ def livy_session_startup_timeout_seconds(): @_with_override def fatal_error_suggestion(): - return u"""The code failed because of a fatal error: + return """The code failed because of a fatal error: \t{}. Some things to try: @@ -201,7 +236,7 @@ def default_samplefraction(): @_with_override def pyspark_dataframe_encoding(): - return u'utf-8' + return "utf-8" @_with_override @@ -267,9 +302,7 @@ def cleanup_all_sessions_on_exit(): @_with_override def kerberos_auth_configuration(): - return { - "mutual_authentication": REQUIRED - } + return {"mutual_authentication": REQUIRED} def _credentials_override(f): @@ -278,15 +311,26 @@ def _credentials_override(f): If 'base64_password' is not set, it will fallback to 'password' in config. 
""" credentials = f() - base64_decoded_credentials = {k: credentials.get(k) for k in ('username', 'password', 'url', 'auth')} - base64_password = credentials.get('base64_password') + base64_decoded_credentials = { + k: credentials.get(k) for k in ("username", "password", "url", "auth") + } + base64_password = credentials.get("base64_password") if base64_password is not None: try: - base64_decoded_credentials['password'] = base64.b64decode(base64_password).decode() + base64_decoded_credentials["password"] = base64.b64decode( + base64_password + ).decode() except Exception: exception_type, exception, traceback = sys.exc_info() - msg = "base64_password for %s contains invalid base64 string: %s %s" % (f.__name__, exception_type, exception) + msg = "base64_password for %s contains invalid base64 string: %s %s" % ( + f.__name__, + exception_type, + exception, + ) raise BadUserConfigurationException(msg) - if base64_decoded_credentials['auth'] is None: - base64_decoded_credentials['auth'] = get_auth_value(base64_decoded_credentials['username'], base64_decoded_credentials['password']) + if base64_decoded_credentials["auth"] is None: + base64_decoded_credentials["auth"] = get_auth_value( + base64_decoded_credentials["username"], + base64_decoded_credentials["password"], + ) return base64_decoded_credentials diff --git a/sparkmagic/sparkmagic/utils/constants.py b/sparkmagic/sparkmagic/utils/constants.py index 10fc7e926..59b7c099c 100644 --- a/sparkmagic/sparkmagic/utils/constants.py +++ b/sparkmagic/sparkmagic/utils/constants.py @@ -9,7 +9,11 @@ SESSION_KIND_SPARK = "spark" SESSION_KIND_PYSPARK = "pyspark" SESSION_KIND_SPARKR = "sparkr" -SESSION_KINDS_SUPPORTED = [SESSION_KIND_SPARK, SESSION_KIND_PYSPARK, SESSION_KIND_SPARKR] +SESSION_KINDS_SUPPORTED = [ + SESSION_KIND_SPARK, + SESSION_KIND_PYSPARK, + SESSION_KIND_SPARKR, +] LIBRARY_LOADED_EVENT = "notebookLoaded" CLUSTER_CHANGE_EVENT = "notebookClusterChange" @@ -72,34 +76,57 @@ KILLED_SESSION_STATUS = "killed" RECOVERING_SESSION_STATUS = "recovering" -POSSIBLE_SESSION_STATUS = [NOT_STARTED_SESSION_STATUS, IDLE_SESSION_STATUS, STARTING_SESSION_STATUS, - BUSY_SESSION_STATUS, ERROR_SESSION_STATUS, DEAD_SESSION_STATUS, - SUCCESS_SESSION_STATUS, SHUT_DOWN_SESSION_STATUS, RUNNING_SESSION_STATUS, - KILLED_SESSION_STATUS, RECOVERING_SESSION_STATUS] -FINAL_STATUS = [DEAD_SESSION_STATUS, ERROR_SESSION_STATUS, SUCCESS_SESSION_STATUS, - KILLED_SESSION_STATUS] +POSSIBLE_SESSION_STATUS = [ + NOT_STARTED_SESSION_STATUS, + IDLE_SESSION_STATUS, + STARTING_SESSION_STATUS, + BUSY_SESSION_STATUS, + ERROR_SESSION_STATUS, + DEAD_SESSION_STATUS, + SUCCESS_SESSION_STATUS, + SHUT_DOWN_SESSION_STATUS, + RUNNING_SESSION_STATUS, + KILLED_SESSION_STATUS, + RECOVERING_SESSION_STATUS, +] +FINAL_STATUS = [ + DEAD_SESSION_STATUS, + ERROR_SESSION_STATUS, + SUCCESS_SESSION_STATUS, + KILLED_SESSION_STATUS, +] ERROR_STATEMENT_STATUS = "error" CANCELLED_STATEMENT_STATUS = "cancelled" AVAILABLE_STATEMENT_STATUS = "available" -FINAL_STATEMENT_STATUS = [ERROR_STATEMENT_STATUS, CANCELLED_STATEMENT_STATUS, AVAILABLE_STATEMENT_STATUS] +FINAL_STATEMENT_STATUS = [ + ERROR_STATEMENT_STATUS, + CANCELLED_STATEMENT_STATUS, + AVAILABLE_STATEMENT_STATUS, +] DELETE_SESSION_ACTION = "delete" START_SESSION_ACTION = "start" DO_NOTHING_ACTION = "nothing" -INTERNAL_ERROR_MSG = "An internal error was encountered.\n" \ - "Please file an issue at https://github.com/jupyter-incubator/sparkmagic\nError:\n{}" +INTERNAL_ERROR_MSG = ( + "An internal error was encountered.\n" + "Please file an issue at 
https://github.com/jupyter-incubator/sparkmagic\nError:\n{}" + ) EXPECTED_ERROR_MSG = "An error was encountered:\n{}" YARN_RESOURCE_LIMIT_MSG = "Queue's AM resource limit exceeded." -RESOURCE_LIMIT_WARNING = "Warning: The Spark session does not have enough YARN resources to start. {}" +RESOURCE_LIMIT_WARNING = ( + "Warning: The Spark session does not have enough YARN resources to start. {}" +) COMMAND_INTERRUPTED_MSG = "Interrupted by user" -COMMAND_CANCELLATION_FAILED_MSG = "Interrupted by user but Livy failed to cancel the Spark statement. "\ - "The Livy session might have become unusable." +COMMAND_CANCELLATION_FAILED_MSG = ( + "Interrupted by user but Livy failed to cancel the Spark statement. " + "The Livy session might have become unusable." +) -LIVY_HEARTBEAT_TIMEOUT_PARAM = u"heartbeatTimeoutInSecond" -LIVY_KIND_PARAM = u"kind" +LIVY_HEARTBEAT_TIMEOUT_PARAM = "heartbeatTimeoutInSecond" +LIVY_KIND_PARAM = "kind" NO_AUTH = "None" AUTH_KERBEROS = "Kerberos" diff --git a/sparkmagic/sparkmagic/utils/dataframe_parser.py b/sparkmagic/sparkmagic/utils/dataframe_parser.py index 993fb4882..03d842f88 100644 --- a/sparkmagic/sparkmagic/utils/dataframe_parser.py +++ b/sparkmagic/sparkmagic/utils/dataframe_parser.py @@ -1,6 +1,6 @@ import re -import itertools -from collections import OrderedDict +import itertools +from collections import OrderedDict from enum import Enum from functools import partial @@ -33,22 +33,28 @@ """ -header_top_pattern = r'[ \t\f\v]*(?P<header_top>\+[-+]*\+)[\n\r]?' -header_content_pattern = r'[ \t\f\v]*(?P<header_content>\|.*\|)[\n\r]' -header_bottom_pattern = r'(?P<header_bottom>[ \t\f\v]*\+[-+]*\+[\n\r])' -row_content_pattern = r'([ \t\f\v]*\|.*\|[\n\r])*' -footer_pattern = r'(?P