Fix deserializing RandomStateField when its value is None (#3149)
chaokunyang committed Jun 16, 2022
1 parent 96ec5a9 commit 8ca0d84
Showing 4 changed files with 72 additions and 162 deletions.
206 changes: 47 additions & 159 deletions mars/dataframe/groupby/sample.py
@@ -16,6 +16,7 @@
import itertools
import random
from collections.abc import Iterable
from typing import Optional, Sequence, Union

import numpy as np
import pandas as pd
@@ -74,99 +75,38 @@ class GroupBySampleILoc(DataFrameOperand, DataFrameOperandMixin):
_op_code_ = opcodes.GROUPBY_SAMPLE_ILOC
_op_module_ = "dataframe.groupby"

_groupby_params = DictField("groupby_params")
_size = Int64Field("size")
_frac = Float32Field("frac")
_replace = BoolField("replace")
_weights = KeyField("weights")
_seed = Int32Field("seed")
_random_state = RandomStateField("random_state")
_errors = StringField("errors")
groupby_params = DictField("groupby_params", default=None)
size = Int64Field("size", default=None)
frac = Float32Field("frac", default=None)
replace = BoolField("replace", default=None)
weights = KeyField("weights", default=None)
seed = Int32Field("seed", default=None)
_random_state = RandomStateField("random_state", default=None)
errors = StringField("errors", default=None)

_random_col_id = Int32Field("random_col_id")
random_col_id = Int32Field("random_col_id", default=None)

# for chunks
# num of instances for chunks
_left_iloc_bound = Int64Field("left_iloc_bound")

def __init__(
self,
groupby_params=None,
size=None,
frac=None,
replace=None,
weights=None,
random_state=None,
seed=None,
errors=None,
left_iloc_bound=None,
random_col_id=None,
**kw
):
super().__init__(
_groupby_params=groupby_params,
_size=size,
_frac=frac,
_seed=seed,
_replace=replace,
_weights=weights,
_random_state=random_state,
_errors=errors,
_left_iloc_bound=left_iloc_bound,
_random_col_id=random_col_id,
**kw
)
if self._random_col_id is None:
self._random_col_id = random.randint(10000, 99999)

@property
def groupby_params(self):
return self._groupby_params

@property
def size(self):
return self._size
left_iloc_bound = Int64Field("left_iloc_bound", default=None)

@property
def frac(self):
return self._frac

@property
def replace(self):
return self._replace

@property
def weights(self):
return self._weights

@property
def seed(self):
return self._seed
def __init__(self, random_state=None, **kw):
super().__init__(_random_state=random_state, **kw)
if self.random_col_id is None:
self.random_col_id = random.randint(10000, 99999)

@property
def random_state(self):
if self._random_state is None:
self._random_state = np.random.RandomState(self.seed)
return self._random_state

@property
def errors(self):
return self._errors

@property
def left_iloc_bound(self):
return self._left_iloc_bound

@property
def random_col_id(self):
return self._random_col_id

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
input_iter = iter(inputs)
next(input_iter)
if isinstance(self.weights, ENTITY_TYPE):
self._weights = next(input_iter)
self.weights = next(input_iter)

def __call__(self, df):
self._output_types = [OutputType.tensor]
@@ -211,9 +151,9 @@ def tile(cls, op: "GroupBySampleILoc"):
left_ilocs = np.array((0,) + in_df.nsplits[0]).cumsum()
for inp_chunk, weight_chunk in zip(in_df.chunks, weights_iter):
new_op = op.copy().reset_key()
new_op._left_iloc_bound = int(left_ilocs[inp_chunk.index[0]])
new_op.left_iloc_bound = int(left_ilocs[inp_chunk.index[0]])
new_op.stage = OperandStage.map
new_op._output_types = [OutputType.dataframe]
new_op.output_types = [OutputType.dataframe]

inp_chunks = [inp_chunk]
if weight_chunk is not None:
@@ -252,9 +192,9 @@ def tile(cls, op: "GroupBySampleILoc"):
for group_chunk, seed in zip(grouped.chunks, seeds):
new_op = op.copy().reset_key()
new_op.stage = OperandStage.reduce
new_op._weights = None
new_op.weights = None
new_op._random_state = None
new_op._seed = seed
new_op.seed = seed

result_chunks.append(
new_op.new_chunk(
@@ -330,87 +270,32 @@ class GroupBySample(MapReduceOperand, DataFrameOperandMixin):
_op_code_ = opcodes.RAND_SAMPLE
_op_module_ = "dataframe.groupby"

_groupby_params = DictField("groupby_params")
_size = Int64Field("size")
_frac = Float32Field("frac")
_replace = BoolField("replace")
_weights = KeyField("weights")
_seed = Int32Field("seed")
_random_state = RandomStateField("random_state")
_errors = StringField("errors")
groupby_params = DictField("groupby_params", default=None)
size = Int64Field("size", default=None)
frac = Float32Field("frac", default=None)
replace = BoolField("replace", default=None)
weights = KeyField("weights", default=None)
seed = Int32Field("seed", default=None)
_random_state = RandomStateField("random_state", default=None)
errors = StringField("errors", default=None)

# for chunks
# num of instances for chunks
_input_nsplits = NDArrayField("input_nsplits")

def __init__(
self,
groupby_params=None,
size=None,
frac=None,
replace=None,
weights=None,
random_state=None,
seed=None,
errors=None,
input_nsplits=None,
**kw
):
super().__init__(
_groupby_params=groupby_params,
_size=size,
_frac=frac,
_seed=seed,
_replace=replace,
_weights=weights,
_random_state=random_state,
_errors=errors,
_input_nsplits=input_nsplits,
**kw
)

@property
def groupby_params(self):
return self._groupby_params

@property
def size(self):
return self._size
input_nsplits = NDArrayField("input_nsplits", default=None)

@property
def frac(self):
return self._frac

@property
def replace(self):
return self._replace

@property
def weights(self):
return self._weights

@property
def seed(self):
return self._seed
def __init__(self, random_state=None, **kw):
super().__init__(_random_state=random_state, **kw)

@property
def random_state(self):
return self._random_state

@property
def errors(self):
return self._errors

@property
def input_nsplits(self):
return self._input_nsplits

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
input_iter = iter(inputs)
next(input_iter)
if isinstance(self.weights, ENTITY_TYPE):
self._weights = next(input_iter)
self.weights = next(input_iter)

def __call__(self, groupby):
df = groupby
@@ -481,9 +366,9 @@ def _tile_distributed(cls, op: "GroupBySample", in_df, weights):
for c in sampled_iloc.chunks:
new_op = op.copy().reset_key()
new_op.stage = OperandStage.map
new_op._weights = None
new_op._output_types = [OutputType.tensor]
new_op._input_nsplits = np.array(in_df.nsplits[0])
new_op.weights = None
new_op.output_types = [OutputType.tensor]
new_op.input_nsplits = np.array(in_df.nsplits[0])

map_chunks.append(
new_op.new_chunk(
@@ -498,13 +383,13 @@ def _tile_distributed(cls, op: "GroupBySample", in_df, weights):
reduce_chunks = []
for ordinal, src_chunk in enumerate(in_df.chunks):
new_op = op.copy().reset_key()
new_op._weights = None
new_op._output_types = [OutputType.tensor]
new_op.weights = None
new_op.output_types = [OutputType.tensor]
new_op.stage = OperandStage.reduce
new_op.reducer_index = (src_chunk.index[0],)
new_op.reducer_ordinal = ordinal
new_op.n_reducers = len(in_df.chunks)
new_op._input_nsplits = np.array(in_df.nsplits[0])
new_op.input_nsplits = np.array(in_df.nsplits[0])

reduce_chunks.append(
new_op.new_chunk(
@@ -623,12 +508,12 @@ def execute(cls, ctx, op: "GroupBySample"):

def groupby_sample(
groupby,
n=None,
frac=None,
replace=False,
weights=None,
random_state=None,
errors="ignore",
n: Optional[int] = None,
frac: Optional[float] = None,
replace: bool = False,
weights: Union[Sequence, pd.Series, None] = None,
random_state: Optional[np.random.RandomState] = None,
errors: str = "ignore",
):
"""
Return a random sample of items from each group.
@@ -726,6 +611,9 @@ def groupby_sample(
rs = copy.deepcopy(
random_state.to_numpy() if hasattr(random_state, "to_numpy") else random_state
)
if not isinstance(rs, np.random.RandomState): # pragma: no cover
rs = np.random.RandomState(rs)

op = GroupBySample(
size=n,
frac=frac,
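The guard added at the end of `groupby_sample` normalizes whatever the caller passes (None, an integer seed, or an existing state object) into a real `np.random.RandomState` before it reaches the operand. Below is a minimal standalone sketch of that normalization; `normalize_random_state` is a hypothetical helper mirroring the guard, and the `to_numpy()` branch assumes a Mars-style wrapper that can export a numpy state. The rest is plain numpy:

```python
import copy

import numpy as np


def normalize_random_state(random_state):
    # Hypothetical helper mirroring the guard in groupby_sample.
    # A Mars-style RandomState wrapper exposes to_numpy(); plain numpy
    # states and raw seeds pass through unchanged.
    rs = copy.deepcopy(
        random_state.to_numpy() if hasattr(random_state, "to_numpy") else random_state
    )
    if not isinstance(rs, np.random.RandomState):
        # RandomState(None) seeds from OS entropy and RandomState(int) is a
        # deterministic seed, so every input kind ends up normalized.
        rs = np.random.RandomState(rs)
    return rs


assert isinstance(normalize_random_state(None), np.random.RandomState)
assert isinstance(normalize_random_state(42), np.random.RandomState)
```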
6 changes: 3 additions & 3 deletions mars/serialization/serializables/core.py
@@ -183,7 +183,7 @@ def _get_field_values(cls, obj: Serializable, fields):
for field in fields:
try:
value = field.get(obj)
if field.on_serialize:
if field.on_serialize is not None:
value = field.on_serialize(value)
except AttributeError:
# Most field values are not None, serialize by list is more efficient than dict.
@@ -209,14 +209,14 @@ def _set_field_value(obj: Serializable, field: Field, value):
if value is no_default:
return
if type(value) is Placeholder:
if field.on_deserialize:
if field.on_deserialize is not None:
value.callbacks.append(
lambda v: field.set(obj, field.on_deserialize(v))
)
else:
value.callbacks.append(lambda v: field.set(obj, v))
else:
if field.on_deserialize:
if field.on_deserialize is not None:
field.set(obj, field.on_deserialize(value))
else:
field.set(obj, value)
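In `core.py`, the truthiness tests on `on_serialize` and `on_deserialize` become explicit `is not None` checks. For an arbitrary callback object the two are not equivalent: any object can define `__bool__`, so a callback that is set could still evaluate falsy and be silently skipped. A contrived but legal illustration of the difference, not taken from Mars itself:

```python
class TracingCallback:
    """A callable whose truthiness reflects whether it has fired yet."""

    def __init__(self):
        self.calls = 0

    def __call__(self, value):
        self.calls += 1
        return value

    def __bool__(self):
        # Falsy until the first invocation.
        return self.calls > 0


cb = TracingCallback()

# Truthiness check: silently skips a perfectly valid callback.
if cb:
    print("never reached on first use")

# Identity check: runs whenever a callback is actually set.
if cb is not None:
    cb("payload")
assert cb.calls == 1
```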
3 changes: 3 additions & 0 deletions mars/tensor/random/core.py
@@ -313,6 +313,9 @@ def _on_serialize_random_state(rs):


def _on_deserialize_random_state(tup):
if tup is None:
return None

rs = np.random.RandomState()
rs.set_state(tup)
return rs
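`_on_deserialize_random_state` previously fed the stored value straight into `RandomState.set_state`, which raises when handed None; the new early return lets a None field value round-trip as None. A self-contained sketch of both halves of the conversion (the serialize side is an assumption mirroring numpy's `get_state`/`set_state` pair, not a verbatim copy of Mars' `_on_serialize_random_state`):

```python
import numpy as np


def serialize_random_state(rs):
    # A RandomState collapses to its picklable state tuple; None stays None.
    return rs.get_state() if rs is not None else None


def deserialize_random_state(tup):
    # The fix: bail out before set_state(None) can raise.
    if tup is None:
        return None
    rs = np.random.RandomState()
    rs.set_state(tup)
    return rs


rs = np.random.RandomState(123)
clone = deserialize_random_state(serialize_random_state(rs))
assert rs.randint(0, 100) == clone.randint(0, 100)  # identical streams
assert deserialize_random_state(serialize_random_state(None)) is None
```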
19 changes: 19 additions & 0 deletions mars/tensor/random/tests/test_random.py
@@ -18,6 +18,8 @@
import pytest

from ....core import tile
from ....serialization import serialize, deserialize
from ....serialization.serializables import Serializable
from ...datasource import tensor as from_ndarray
from .. import (
beta,
@@ -31,6 +33,23 @@
shuffle,
RandomState,
)
from ..core import RandomStateField


class ObjWithRandomStateField(Serializable):
random_state = RandomStateField("random_state")


@pytest.mark.parametrize("rs", [None, np.random.RandomState()])
def test_serial_random_state_field(rs):
res = deserialize(*serialize(ObjWithRandomStateField(rs)))
if rs is None:
assert res.random_state is None
else:
original_state = rs.get_state()
new_state = res.random_state.get_state()
assert original_state[0] == new_state[0]
np.testing.assert_array_equal(original_state[1], new_state[1])


def test_random():
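For context, the case the new test exercises arises whenever a user never supplies `random_state`. A hypothetical end-to-end usage sketch, assuming the pandas-style `groupby(...).sample(...)` entry point this module registers:

```python
import pandas as pd

import mars.dataframe as md

df = md.DataFrame(pd.DataFrame({"key": [1, 1, 2, 2], "val": range(4)}))
# random_state is left as None, so RandomStateField carries a None value,
# exactly the case the deserializer used to choke on.
sampled = df.groupby("key").sample(n=1)
print(sampled.execute())
```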
