diff --git a/datascience/tables.py b/datascience/tables.py index 255fef27..2b9b6e34 100644 --- a/datascience/tables.py +++ b/datascience/tables.py @@ -550,7 +550,7 @@ def apply(self, fn, *column_or_columns): return np.array([fn(row) for row in self.rows]) else: if len(column_or_columns) == 1 and \ - _is_non_string_iterable(column_or_columns[0]): + _util.is_non_string_iterable(column_or_columns[0]): warnings.warn( "column lists are deprecated; pass each as an argument", FutureWarning) column_or_columns = column_or_columns[0] @@ -998,9 +998,14 @@ def append_column(self, label, values, formatter=None): if not isinstance(values, np.ndarray): # Coerce a single value to a sequence - if not _is_non_string_iterable(values): + if not _util.is_non_string_iterable(values): values = [values] * max(self.num_rows, 1) - values = np.array(tuple(values)) + + # Manually cast `values` as an object due to this: https://github.com/data-8/datascience/issues/458 + if any(_util.is_non_string_iterable(el) for el in values): + values = np.array(tuple(values), dtype=object) + else: + values = np.array(tuple(values)) if self.num_rows != 0 and len(values) != self.num_rows: raise ValueError('Column length mismatch. New column does not have ' @@ -1563,7 +1568,7 @@ def group(self, column_or_label, collect=None): Round | | 13 | 4.05 """ # Assume that a call to group with a list of labels is a call to groups - if _is_non_string_iterable(column_or_label) and \ + if _util.is_non_string_iterable(column_or_label) and \ len(column_or_label) != self._num_rows: return self.groups(column_or_label, collect) @@ -1647,7 +1652,7 @@ def groups(self, labels, collect=None): Red | Round | 11 | 3.05 """ # Assume that a call to groups with one label is a call to group - if not _is_non_string_iterable(labels): + if not _util.is_non_string_iterable(labels): return self.group(labels, collect=collect) collect = _zero_on_type_error(collect) @@ -2042,7 +2047,7 @@ def join(self, column_label, other, other_label=None): other_label = column_label # checking to see if joining on multiple columns - if _is_non_string_iterable(column_label): + if _util.is_non_string_iterable(column_label): # then we are going to be joining multiple labels return self._multiple_join(column_label, other, other_label) @@ -5734,7 +5739,7 @@ def _fill_with_zeros(partials, rows, zero=None): zero -- value used when no rows match a particular partial """ assert len(rows) > 0 - if not _is_non_string_iterable(partials): + if not _util.is_non_string_iterable(partials): # Convert partials to tuple for comparison against row slice later partials = [(partial,) for partial in partials] @@ -5753,7 +5758,7 @@ def _fill_with_zeros(partials, rows, zero=None): def _as_labels(column_or_columns): """Return a list of labels for a label or labels.""" - if not _is_non_string_iterable(column_or_columns): + if not _util.is_non_string_iterable(column_or_columns): return [column_or_columns] else: return column_or_columns @@ -5763,7 +5768,7 @@ def _varargs_labels_as_list(label_list): of labels.""" if len(label_list) == 0: return [] - elif not _is_non_string_iterable(label_list[0]): + elif not _util.is_non_string_iterable(label_list[0]): # Assume everything is a label. If not, it'll be caught later. return label_list elif len(label_list) == 1: @@ -5788,17 +5793,6 @@ def _collected_label(collect, label): else: return label - -def _is_non_string_iterable(value): - """Whether a value is iterable.""" - if isinstance(value, str): - return False - if hasattr(value, '__iter__'): - return True - if isinstance(value, collections.abc.Sequence): - return True - return False - def _vertical_x(axis, ticks=None, max_width=5): """Switch labels to vertical if they are long.""" if ticks is None: diff --git a/datascience/util.py b/datascience/util.py index 3ffdb895..9bc9d34c 100644 --- a/datascience/util.py +++ b/datascience/util.py @@ -2,7 +2,7 @@ __all__ = ['make_array', 'percentile', 'plot_cdf_area', 'plot_normal_cdf', 'table_apply', 'proportions_from_distribution', - 'sample_proportions', 'minimize'] + 'sample_proportions', 'minimize', 'is_non_string_iterable'] import numpy as np import pandas as pd @@ -13,6 +13,7 @@ from scipy import optimize import functools import math +import collections # Change matplotlib formatting. TODO incorporate into a style? plt.rcParams["patch.force_edgecolor"] = True @@ -37,6 +38,11 @@ def make_array(*elements): # Specifically added for Windows machines where the default # integer is int32 - see GH issue #339. return np.array(elements, dtype="int64") + + # Manually cast `elements` as an object due to this: https://github.com/data-8/datascience/issues/458 + if any(is_non_string_iterable(el) for el in elements): + return np.array(elements, dtype=object) + return np.array(elements) @@ -241,3 +247,13 @@ def objective(args): return result.x.item(0) else: return result.x + +def is_non_string_iterable(value): + """Whether a value is iterable.""" + if isinstance(value, str): + return False + if hasattr(value, '__iter__'): + return True + if isinstance(value, collections.abc.Sequence): + return True + return False \ No newline at end of file