Skip to content

Commit

Permalink
Support sampling with replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Dec 29, 2015
1 parent 09a75c1 commit 62d3019
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dask/dataframe/core.py
Expand Up @@ -467,13 +467,15 @@ def fillna(self, value):
func = getattr(self._partition_type, 'fillna')
return map_partitions(func, self.column_info, self, value)

def sample(self, frac, random_state=None):
def sample(self, frac, replace=False, random_state=None):
""" Random sample of items
Parameters
----------
frac : float, optional
Fraction of axis items to return.
replace: boolean, optional
Sample with or without replacement. Default = False.
random_state: int or np.random.RandomState
If int create a new RandomState with this as the seed
Otherwise draw from the passed RandomState
Expand All @@ -487,14 +489,15 @@ def sample(self, frac, random_state=None):
if random_state is None:
random_state = np.random.randint(np.iinfo(np.int32).max)

name = 'sample-' + tokenize(self, frac, random_state)
name = 'sample-' + tokenize(self, frac, replace, random_state)
func = getattr(self._partition_type, 'sample')

seeds = different_seeds(self.npartitions, random_state)

dsk = dict(((name, i),
(apply, func, (tuple, [(self._name, i)]),
{'frac': frac, 'random_state': seed}))
{'frac': frac, 'random_state': seed,
'replace': replace}))
for i, seed in zip(range(self.npartitions), seeds))

return self._constructor(merge(self.dask, dsk), name,
Expand Down
9 changes: 9 additions & 0 deletions dask/dataframe/tests/test_dataframe.py
Expand Up @@ -2156,6 +2156,15 @@ def test_sample():
assert a.sample(0.5)._name != a.sample(0.5)._name


def test_sample_without_replacement():
    """Sampling with ``replace=False`` must not yield any index label twice."""
    frame = pd.DataFrame(
        {'x': [1, 2, 3, 4, None, 6], 'y': list('abdabd')},
        index=[10, 20, 30, 40, 50, 60],
    )
    # Split the frame across two partitions before sampling.
    ddf = dd.from_pandas(frame, 2)
    sampled_index = ddf.sample(0.7, replace=False).index.compute()
    # The source index labels are all distinct, so a duplicate in the result
    # would mean some row was drawn more than once.
    assert len(set(sampled_index)) == len(sampled_index)


def test_datetime_accessor():
df = pd.DataFrame({'x': [1, 2, 3, 4]})
df['x'] = df.x.astype('M8[us]')
Expand Down

0 comments on commit 62d3019

Please sign in to comment.