diff --git a/datascience/tables.py b/datascience/tables.py index fe9730da6..2376bb92b 100644 --- a/datascience/tables.py +++ b/datascience/tables.py @@ -1334,7 +1334,69 @@ def annotate(axis, ticks): self._visualize(labels, xticks, overlay, draw, annotate) def scatter(self, column_for_x, overlay=False, fit_line=False, **vargs): - """Plot contents as lines.""" + """Creates scatterplots, optionally adding a line of best fit. + + All scatterplots use the values in ``column_for_x`` as the x-values. A + total of n - 1 scatterplots are created where n is the number of + columns in the table, one for every column other than ``column_for_x``. + + Requires all columns in the table to contain numerical values only. + If the columns contain other types, a ``ValueError`` is raised. + + Args: + ``column_for_x`` (str): The name to use for the x-axis values of the + scatter plots. + + Kwargs: + ``overlay`` (bool): If True, creates one scatterplot with n - 1 + y-values plotted, one for each column other than + ``column_for_x`` (instead of the default behavior of creating n + - 1 scatterplots). Also adds a legend that matches each dot + and best-fit line color to its column. + + ``fit_line`` (bool): If True, draws a line of best fit for each + scatterplot drawn. + + ``vargs``: Additional arguments that get passed into `plt.scatter`. + See http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter + for additional arguments that can be passed into vargs. These + include: `marker` and `norm`, to name a couple. + + Returns: + None + + Raises: + ``ValueError``: The table contains non-numerical values in columns. 
+ + >>> x = [9, 3, 3, 1] + >>> y = [1, 2, 2, 10] + >>> z = [3, 4, 5, 6] + >>> table = Table([x, y, z], ['x', 'y', 'z']) + >>> table + x | y | z + 9 | 1 | 3 + 3 | 2 | 4 + 3 | 2 | 5 + 1 | 10 | 6 + >>> table.scatter('x') # doctest: +SKIP + + + + >>> table.scatter('x', overlay = True) # doctest: +SKIP + + + >>> table.scatter('x', fit_line = True) # doctest: +SKIP + + + + """ + # Check for non-numerical values and raise a ValueError if any found + for col in self: + if any(isinstance(cell, np.flexible) for cell in self[col]): + raise ValueError("The column '{0}' contains non-numerical " + "values. A scatter plot cannot be drawn for this table." + .format(col)) + options = self.default_options.copy() options.update(vargs) xdata, labels = self._split_by_column(column_for_x) @@ -1342,7 +1404,7 @@ def scatter(self, column_for_x, overlay=False, fit_line=False, **vargs): def draw(axis, label, color): axis.scatter(xdata, self[label], color=color, **options) if fit_line: - m,b = np.polyfit(xdata, self[label],1) + m,b = np.polyfit(xdata, self[label], 1) minx, maxx = np.min(xdata),np.max(xdata) axis.plot([minx,maxx],[m*minx+b,m*maxx+b]) diff --git a/docs/tables.rst b/docs/tables.rst index f1ebc98ac..a0579bd5a 100644 --- a/docs/tables.rst +++ b/docs/tables.rst @@ -119,7 +119,9 @@ Visualizations :toctree: _autosummary Table.plot + Table.bar Table.barh Table.pivot_hist Table.hist Table.points + Table.scatter diff --git a/tests/test_tables.py b/tests/test_tables.py index 9011f616f..f77ecdafe 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -884,7 +884,22 @@ def test_split_k_vals(table): # Visualize # ############# +def test_scatter(numbers_table): + """Tests that Table.scatter doesn't raise an error when the table doesn't + contain non-numerical values. Not working right now because of TKinter + issues on Travis. 
+ TODO(sam): Fix Travis so this runs + """ + + # numbers_table.scatter('count') + +def test_scatter_error(table): + """Tests that Table.scatter raises an error when the table contains + non-numerical values.""" + + with pytest.raises(ValueError): + table.scatter('letter') ########### # Queries #