Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,15 +1334,77 @@ def annotate(axis, ticks):
self._visualize(labels, xticks, overlay, draw, annotate)

def scatter(self, column_for_x, overlay=False, fit_line=False, **vargs):
"""Plot contents as lines."""
"""Creates scatterplots, optionally adding a line of best fit.

All scatterplots use the values in ``column_for_x`` as the x-values. A
total of n - 1 scatterplots are created where n is the number of
columns in the table, one for every column other than ``column_for_x``.

Requires all columns in the table to contain numerical values only.
If the columns contain other types, a ``ValueError`` is raised.

Args:
``column_for_x`` (str): The name to use for the x-axis values of the
scatter plots.

Kwargs:
``overlay`` (bool): If True, creates one scatterplot with n - 1
y-values plotted, one for each column other than
``column_for_x`` (instead of the default behavior of creating n
- 1 scatterplots). Also adds a legend that matches each dot
and best-fit line color to its column.

``fit_line`` (bool): If True, draws a line of best fit for each
scatterplot drawn.

``vargs``: Additional arguments that get passed into `plt.scatter`.
See http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter
for additional arguments that can be passed into vargs. These
include: `marker` and `norm`, to name a couple.

Returns:
None

Raises:
``ValueError``: The table contains non-numerical values in columns.

>>> x = [9, 3, 3, 1]
>>> y = [1, 2, 2, 10]
>>> z = [3, 4, 5, 6]
>>> table = Table([x, y, z], ['x', 'y', 'z'])
>>> table
x | y | z
9 | 1 | 3
3 | 2 | 4
3 | 2 | 5
1 | 10 | 6
>>> table.scatter('x') # doctest: +SKIP
<scatterplot of values in y on x>
<scatterplot of values in z on x>

>>> table.scatter('x', overlay = True) # doctest: +SKIP
<scatterplot of values in y and z on x>

>>> table.scatter('x', fit_line = True) # doctest: +SKIP
<scatterplot of values in y on x with line of best fit>
<scatterplot of values in z on x with line of best fit>

"""
# Check for non-numerical values and raise a ValueError if any found
for col in self:
if any(isinstance(cell, np.flexible) for cell in self[col]):
raise ValueError("The column '{0}' contains non-numerical "
"values. A histogram cannot be drawn for this table."
.format(col))

options = self.default_options.copy()
options.update(vargs)
xdata, labels = self._split_by_column(column_for_x)

def draw(axis, label, color):
axis.scatter(xdata, self[label], color=color, **options)
if fit_line:
m,b = np.polyfit(xdata, self[label],1)
m,b = np.polyfit(xdata, self[label], 1)
minx, maxx = np.min(xdata),np.max(xdata)
axis.plot([minx,maxx],[m*minx+b,m*maxx+b])

Expand Down
2 changes: 2 additions & 0 deletions docs/tables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ Visualizations
:toctree: _autosummary

Table.plot
Table.bar
Table.barh
Table.pivot_hist
Table.hist
Table.points
Table.scatter
15 changes: 15 additions & 0 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,22 @@ def test_split_k_vals(table):
# Visualize #
#############

def test_scatter(numbers_table):
    """Tests that Table.scatter doesn't raise an error when the table doesn't
    contain non-numerical values.

    Currently a no-op: the scatter call below is commented out because of
    TKinter issues on Travis.

    TODO(sam): Fix Travis so this runs
    """

    # Disabled until the Travis TKinter issue is resolved:
    # numbers_table.scatter('count')

def test_scatter_error(table):
    """Checks that Table.scatter raises a ValueError when the table
    contains non-numerical values."""

    # Build the expectation first, then run the call under it.
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        table.scatter('letter')

###########
# Queries #
Expand Down