Skip to content

Commit

Permalink
second version of code
Browse files Browse the repository at this point in the history
  • Loading branch information
adnanhemani committed Jun 12, 2022
1 parent 3d806a4 commit 70cb431
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
34 changes: 14 additions & 20 deletions datascience/tables.py
Expand Up @@ -550,7 +550,7 @@ def apply(self, fn, *column_or_columns):
return np.array([fn(row) for row in self.rows])
else:
if len(column_or_columns) == 1 and \
_is_non_string_iterable(column_or_columns[0]):
_util.is_non_string_iterable(column_or_columns[0]):
warnings.warn(
"column lists are deprecated; pass each as an argument", FutureWarning)
column_or_columns = column_or_columns[0]
Expand Down Expand Up @@ -998,9 +998,14 @@ def append_column(self, label, values, formatter=None):

if not isinstance(values, np.ndarray):
# Coerce a single value to a sequence
if not _is_non_string_iterable(values):
if not _util.is_non_string_iterable(values):
values = [values] * max(self.num_rows, 1)
values = np.array(tuple(values))

# Build `values` with an explicit dtype=object when any element is itself a non-string iterable; see https://github.com/data-8/datascience/issues/458
if any(_util.is_non_string_iterable(el) for el in values):
values = np.array(tuple(values), dtype=object)
else:
values = np.array(tuple(values))

if self.num_rows != 0 and len(values) != self.num_rows:
raise ValueError('Column length mismatch. New column does not have '
Expand Down Expand Up @@ -1563,7 +1568,7 @@ def group(self, column_or_label, collect=None):
Round | | 13 | 4.05
"""
# Assume that a call to group with a list of labels is a call to groups
if _is_non_string_iterable(column_or_label) and \
if _util.is_non_string_iterable(column_or_label) and \
len(column_or_label) != self._num_rows:
return self.groups(column_or_label, collect)

Expand Down Expand Up @@ -1647,7 +1652,7 @@ def groups(self, labels, collect=None):
Red | Round | 11 | 3.05
"""
# Assume that a call to groups with one label is a call to group
if not _is_non_string_iterable(labels):
if not _util.is_non_string_iterable(labels):
return self.group(labels, collect=collect)

collect = _zero_on_type_error(collect)
Expand Down Expand Up @@ -2042,7 +2047,7 @@ def join(self, column_label, other, other_label=None):
other_label = column_label

# checking to see if joining on multiple columns
if _is_non_string_iterable(column_label):
if _util.is_non_string_iterable(column_label):
# then we are going to be joining multiple labels
return self._multiple_join(column_label, other, other_label)

Expand Down Expand Up @@ -5734,7 +5739,7 @@ def _fill_with_zeros(partials, rows, zero=None):
zero -- value used when no rows match a particular partial
"""
assert len(rows) > 0
if not _is_non_string_iterable(partials):
if not _util.is_non_string_iterable(partials):
# Convert partials to tuple for comparison against row slice later
partials = [(partial,) for partial in partials]

Expand All @@ -5753,7 +5758,7 @@ def _fill_with_zeros(partials, rows, zero=None):

def _as_labels(column_or_columns):
"""Return a list of labels for a label or labels."""
if not _is_non_string_iterable(column_or_columns):
if not _util.is_non_string_iterable(column_or_columns):
return [column_or_columns]
else:
return column_or_columns
Expand All @@ -5763,7 +5768,7 @@ def _varargs_labels_as_list(label_list):
of labels."""
if len(label_list) == 0:
return []
elif not _is_non_string_iterable(label_list[0]):
elif not _util.is_non_string_iterable(label_list[0]):
# Assume everything is a label. If not, it'll be caught later.
return label_list
elif len(label_list) == 1:
Expand All @@ -5788,17 +5793,6 @@ def _collected_label(collect, label):
else:
return label


def _is_non_string_iterable(value):
"""Whether a value is iterable."""
if isinstance(value, str):
return False
if hasattr(value, '__iter__'):
return True
if isinstance(value, collections.abc.Sequence):
return True
return False

def _vertical_x(axis, ticks=None, max_width=5):
"""Switch labels to vertical if they are long."""
if ticks is None:
Expand Down
18 changes: 17 additions & 1 deletion datascience/util.py
Expand Up @@ -2,7 +2,7 @@

__all__ = ['make_array', 'percentile', 'plot_cdf_area', 'plot_normal_cdf',
'table_apply', 'proportions_from_distribution',
'sample_proportions', 'minimize']
'sample_proportions', 'minimize', 'is_non_string_iterable']

import numpy as np
import pandas as pd
Expand All @@ -13,6 +13,7 @@
from scipy import optimize
import functools
import math
import collections

# Change matplotlib formatting. TODO incorporate into a style?
plt.rcParams["patch.force_edgecolor"] = True
Expand All @@ -37,6 +38,11 @@ def make_array(*elements):
# Specifically added for Windows machines where the default
# integer is int32 - see GH issue #339.
return np.array(elements, dtype="int64")

# Build the array with an explicit dtype=object when any element is itself a non-string iterable; see https://github.com/data-8/datascience/issues/458
if any(is_non_string_iterable(el) for el in elements):
return np.array(elements, dtype=object)

return np.array(elements)


Expand Down Expand Up @@ -241,3 +247,13 @@ def objective(args):
return result.x.item(0)
else:
return result.x

def is_non_string_iterable(value):
    """Return whether ``value`` is an iterable other than a ``str``.

    Args:
        value: Any object.

    Returns:
        ``False`` for ``str`` instances. ``True`` for any other object that
        defines ``__iter__`` or is an instance (including a registered
        virtual subclass) of ``collections.abc.Sequence``. ``False``
        otherwise. Only ``str`` is special-cased; other iterables are
        reported as iterable.
    """
    # Import the submodule explicitly: a bare ``import collections`` does not
    # guarantee that ``collections.abc`` has been loaded.
    import collections.abc

    # Strings are iterable, but callers want them treated as single values.
    if isinstance(value, str):
        return False
    return (hasattr(value, '__iter__')
            or isinstance(value, collections.abc.Sequence))

0 comments on commit 70cb431

Please sign in to comment.