Skip to content
Merged
198 changes: 143 additions & 55 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
__all__ = ['Table', 'Q']


import abc
import collections
import collections.abc
import functools
import inspect
import itertools
import operator
import random
import urllib.parse

import numpy as np
Expand All @@ -22,25 +21,33 @@
import datascience.util as _util



class _Taker:
class _RowSelector(metaclass=abc.ABCMeta):
def __init__(self, table):
self._table = table

def __call__(self, row_numbers_or_slice):
"""Return a Table of a sequence of rows taken by number.
return self[row_numbers_or_slice]

@abc.abstractmethod
def __getitem__(self, item):
raise NotImplementedError()


class _RowTaker(_RowSelector):
def __getitem__(self, row_indices_or_slice):
"""Return a new Table of a sequence of rows taken by number.

Args:
``row_numbers_or_slice`` (slice or integer or list of integers):
The list of row numbers or a slice to be selected.
``row_indices_or_slice`` (integer or list of integers or slice):
The row index, list of row indices or a slice of row indices to
be selected.

Returns:
A ``Table`` containing only the selected rows.
A new instance of ``Table``.

>>> grade = ['A+', 'A', 'A-', 'B+', 'B', 'B-']
>>> gpa = [4, 4, 3.7, 3.3, 3, 2.7]
>>> t = Table([grade, gpa], ['letter grade', 'gpa'])

>>> t
letter grade | gpa
A+ | 4
Expand All @@ -63,7 +70,7 @@ def __call__(self, row_numbers_or_slice):
A- | 3.7
A | 4
A+ | 4
>>> print(t.take([1, 5]))
>>> t.take([1, 5])
letter grade | gpa
A | 4
B- | 2.7
Expand All @@ -88,23 +95,92 @@ def __call__(self, row_numbers_or_slice):
A+ | 4

"""
return self[row_numbers_or_slice]

def __getitem__(self, i):
if isinstance(i, collections.Iterable):
columns = [np.take(column, i, axis=0)
if isinstance(row_indices_or_slice, collections.Iterable):
columns = [np.take(column, row_indices_or_slice, axis=0)
for column in self._table._columns.values()]
return self._table._with_columns(columns)
elif isinstance(i, slice):
columns = [column[i] for column in self._table._columns.values()]
return self._table._with_columns(columns)
else:
rows = self._table.rows[i]
cols = self._table._columns.keys()
if not isinstance(rows, list):
rows = [rows]

return Table.from_rows(rows, cols)
rows = self._table.rows[row_indices_or_slice]
if isinstance(rows, Table.Row):
rows = [rows]
return Table.from_rows(rows, self._table.column_labels)


class _RowExcluder(_RowSelector):
def __getitem__(self, row_indices_or_slice):
"""Return a new Table without a sequence of rows excluded by number.

Args:
``row_indices_or_slice`` (integer or list of integers or slice):
The row index, list of row indices or a slice of row indices
to be excluded.

Returns:
A new instance of ``Table``.

>>> grade = ['A+', 'A', 'A-', 'B+', 'B', 'B-']
>>> gpa = [4, 4, 3.7, 3.3, 3, 2.7]
>>> t = Table([grade, gpa], ['letter grade', 'gpa'])
>>> t
letter grade | gpa
A+ | 4
A | 4
A- | 3.7
B+ | 3.3
B | 3
B- | 2.7
>>> t.exclude(4)
letter grade | gpa
A+ | 4
A | 4
A- | 3.7
B+ | 3.3
B- | 2.7
>>> t.exclude(-1)
letter grade | gpa
A+ | 4
A | 4
A- | 3.7
B+ | 3.3
B | 3
>>> t.exclude([1, 3, 4])
letter grade | gpa
A+ | 4
A- | 3.7
B- | 2.7
>>> t.exclude(range(3))
letter grade | gpa
B+ | 3.3
B | 3
B- | 2.7

Note that ``exclude`` also supports NumPy-like indexing and slicing:

>>> t.exclude[:3]
letter grade | gpa
B+ | 3.3
B | 3
B- | 2.7

>>> t.exclude[1, 3, 4]
letter grade | gpa
A+ | 4
A- | 3.7
B- | 2.7
"""
if isinstance(row_indices_or_slice, collections.Iterable):
without_row_indices = set(row_indices_or_slice)
rows = [row for index, row in enumerate(self._table.rows[:])
if index not in without_row_indices]
return Table.from_rows(rows, self._table.column_labels)

row_slice = row_indices_or_slice
if not isinstance(row_slice, slice):
row_slice %= self._table.num_rows
row_slice = slice(row_slice, row_slice+1)
return Table.from_rows(itertools.chain(self._table.rows[:row_slice.start or 0],
self._table.rows[row_slice.stop:]),
self._table.column_labels)


class Table(collections.abc.MutableMapping):
Expand Down Expand Up @@ -167,12 +243,16 @@ def __init__(self, columns=None, labels=None,
for column, label in zip(columns, labels):
self[label] = column

self.take = _Taker(self)
self.take = _RowTaker(self)
self.exclude = _RowExcluder(self)

# These, along with a snippet below, are necessary for Sphinx to
# correctly load the `take` and `exclude` docstrings. The definitions
# will be over-ridden during class instantiation.
def take(self):
raise NotImplementedError()

# This, along with a snippet below, is necessary for Sphinx to
# correctly load the `take` docstring. The definition will be
# over-ridden during class instantiation.
def take(cls):
def exclude(self):
raise NotImplementedError()

@classmethod
Expand Down Expand Up @@ -493,7 +573,6 @@ def append_column(self, label, values):

self._columns[label] = values


def relabel(self, column_label, new_label):
"""Change the labels of columns specified by ``column_label`` to
labels in ``new_label``.
Expand Down Expand Up @@ -599,6 +678,7 @@ def drop(self, column_label_or_labels):
"""Return a Table with only columns other than selected label or labels."""
exclude = _as_labels(column_label_or_labels)
return self.select([c for c in self.column_labels if c not in exclude])

def where(self, column_or_label, value=None):
"""Return a Table of rows for which the column is value or a non-zero value."""
column = self._get_column(column_or_label)
Expand Down Expand Up @@ -732,7 +812,7 @@ def pivot_bin(self, pivot_columns, value_column, bins=None, **vargs) :
# refine bins by taking a histogram over all the data
if bins is not None:
vargs['bins'] = bins
_,rbins = np.histogram(self[value_column],**vargs)
_, rbins = np.histogram(self[value_column],**vargs)
# create a table with these bins a first column and counts for each group
vargs['bins'] = rbins
binned = Table([rbins],['bin'])
Expand Down Expand Up @@ -937,7 +1017,7 @@ def split(self, k):
"""
if not 1 <= k <= self.num_rows - 1:
raise ValueError("Invalid value of k. k must be between 1 and the"
"number of rows - 1")
"number of rows - 1")

rows = [self.rows[index] for index in
np.random.permutation(self.num_rows)]
Expand All @@ -946,7 +1026,7 @@ def split(self, k):
for column_label in self._formats :
first._formats[column_label] = self._formats[column_label]
rest._formats[column_label] = self._formats[column_label]
return first,rest
return first, rest

def with_column(self, label, values):
"""Returns a new table with new column included.
Expand Down Expand Up @@ -1236,11 +1316,13 @@ def plot(self, column_for_xticks=None, overlay=False, **vargs):
options = self.default_options.copy()
options.update(vargs)
xticks, labels = self._split_by_column(column_for_xticks)

def draw(axis, label, color):
if xticks is None:
axis.plot(self[label], color=color, **options)
else :
axis.plot(xticks, self[label], color=color, **options)

def annotate(axis, ticks):
tick_labels = [ticks[int(l)] for l in axis.get_xticks() if l<len(ticks)]
axis.set_xticklabels(tick_labels, rotation='vertical')
Expand All @@ -1252,12 +1334,14 @@ def scatter(self, column_for_x, overlay=False, fit_line=False, **vargs):
options = self.default_options.copy()
options.update(vargs)
xdata, labels = self._split_by_column(column_for_x)

def draw(axis, label, color):
axis.scatter(xdata, self[label], color=color, **options)
if fit_line:
m,b = np.polyfit(xdata, self[label],1)
minx, maxx = np.min(xdata),np.max(xdata)
axis.plot([minx,maxx],[m*minx+b,m*maxx+b])

def annotate(axis, ticks):
return None
self._visualize(labels, None, overlay, draw, annotate)
Expand Down Expand Up @@ -1348,16 +1432,18 @@ def barh(self, column_for_categories, overlay=False, **vargs):
width = 1 - 2 * margin
if overlay:
width /= len(labels)

def draw(axis, label, color):
if overlay:
ypos = index + margin + (1-2*margin)*labels.index(label)/len(labels)
else:
ypos = index
#barh plots entries in reverse order from bottom to top
# barh plots entries in reverse order from bottom to top
axis.barh(ypos, self[label][::-1], width, color=color, **options)

def annotate(axis, ticks):
axis.set_yticks(index+0.5) # Center labels on bars
#barh plots entries in reverse order from bottom to top
# barh plots entries in reverse order from bottom to top
axis.set_yticklabels(ticks[::-1], stretch='ultra-condensed')
height = max(4, len(index)/2)
if 'height' in vargs:
Expand Down Expand Up @@ -1448,12 +1534,14 @@ def bar(self, column_for_categories=None, overlay=False, **vargs):
width = 1 - 2 * margin
if overlay:
width /= len(labels)

def draw(axis, label, color):
if overlay:
xpos = index + margin + (1-2*margin)*labels.index(label)/len(labels)
else:
xpos = index
axis.bar(xpos, self[label], 1.0, color=color, **options)

def annotate(axis, ticks):
if (xticks is not None) and (len(xticks) < 10) :
axis.set_xticks(index+0.5) # Center labels on bars
Expand Down Expand Up @@ -1632,6 +1720,19 @@ def points(self, column__lat, column__long, labels=None, colors=None, **kwargs)
# Support #
###########

class Row(tuple):
_table = None # Set by subclasses in Rows

def __getattr__(self, column_label):
return self[self._table.column_index(column_label)]

def __repr__(self):
return 'Row({})'.format(', '.join('{}={}'.format(
self._table.column_labels[i], v.__repr__()) for i, v in enumerate(self)))

def asdict(self):
return collections.OrderedDict(zip(self._table.column_labels, self))

class Rows(collections.abc.Sequence):
"""An iterable view over the rows in a table."""
def __init__(self, table):
Expand All @@ -1640,27 +1741,13 @@ def __init__(self, table):

def __getitem__(self, i):
if isinstance(i, slice):
return [self[j] for j in range(*i.indices(len(self)))]
return (self[j] for j in range(*i.indices(len(self))))

labels = tuple(self._table.column_labels)
if labels != self._labels:
self._labels = labels

class Row(tuple):
__table = self._table

def __getattr__(self, column_label):
return self[self.__table.column_index(column_label)]

def __repr__(self):
return 'Row({})'.format(', '.join('{}={}'.format(
self.__table.column_labels[i], v.__repr__()) for i, v in enumerate(self)))

def asdict(self):
return collections.OrderedDict(zip(self.__table.column_labels, self))

self._row = Row

return self._row([c[i] for c in self._table._columns.values()])
self._row = type('Row', (Table.Row, ), dict(_table=self._table))
return self._row(c[i] for c in self._table._columns.values())

def __len__(self):
return self._table.num_rows
Expand All @@ -1669,8 +1756,9 @@ def __repr__(self):
return '{0}({1})'.format(type(self).__name__, repr(self._table))


# For Sphinx: grab the docstring from `Taker.__call__`
Table.take.__doc__ = _Taker.__call__.__doc__
# For Sphinx: grab the docstrings from `Taker.__getitem__` and `Withouter.__getitem__`
Table.take.__doc__ = _RowTaker.__getitem__.__doc__
Table.exclude.__doc__ = _RowExcluder.__getitem__.__doc__


class Q:
Expand Down
Loading