Skip to content

Commit

Permalink
add state-id in group-by to share a state
Browse files Browse the repository at this point in the history
  • Loading branch information
MainRo committed Apr 11, 2022
1 parent dfb5941 commit 87df912
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rxsci/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .identity import identity
from .last import last
from .map import map
from .multiplex import multiplex, mux_observable, demux_mux_observable
from .multiplex import multiplex, mux_observable, demux_observable, demux_mux_observable
from .on_subscribe import on_subscribe
from .pandas import from_pandas, to_pandas
from .progress import progress
Expand Down
29 changes: 24 additions & 5 deletions rxsci/operators/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from .multiplex import demux_mux_observable


def group_by_mux(key_mapper):
def group_by_mux(key_mapper, state_id):
outer_observer = Subject()

def _group_by(source):
def on_subscribe(observer, scheduler):
state = None
shared_state = None

def on_next(i):
nonlocal state
nonlocal shared_state

if type(i) is rs.OnNextMux:
key = i.key
Expand All @@ -27,6 +29,8 @@ def on_next(i):

elif type(i) is rs.OnCreateMux:
i.store.add_key(state, i.key)
if shared_state:
i.store.add_key(shared_state, i.key)
outer_observer.on_next(i)

elif type(i) is rs.OnCompletedMux:
Expand All @@ -45,12 +49,22 @@ def on_next(i):
outer_observer.on_next(i)

elif type(i) is rs.state.ProbeStateTopology:
state = i.topology.create_mapper(name="groupby")
state = i.topology.create_mapper(
name="groupby",
)
if state_id is not None:
shared_state = i.topology.create_mapper(
name="groupby",
state_id=state_id
)
observer.on_next(i)
outer_observer.on_next(i)
else:
if state is None:
observer.on_error(ValueError("No state configured in group_by operator. A state store operator is probably missing in the graph"))
observer.on_error(ValueError(
"No state configured in group_by operator."
"A state store operator is probably missing in the graph"
))
observer.on_next(i)

return source.subscribe(
Expand All @@ -64,7 +78,7 @@ def on_next(i):
return _group_by, outer_observer


def group_by(key_mapper, pipeline):
def group_by(key_mapper, pipeline, state_id=None):
"""Groups items of according to a key mapper
The source must be a MuxObservable.
Expand All @@ -78,17 +92,22 @@ def group_by(key_mapper, pipeline):
+a-----b--c-|
+1--2-----3-------|
Several instances of the group_by operator can share the same state by
using the same state_id value. This allows some downstream operators to
combine observables by key. This is for example used in merge_asof.
Examples:
>>> rs.ops.group_by(lambda i: i.category, rs.ops.count)
Args:
key_mapper: A function to extract the key from each item
pipeline: The Rx pipe to execute on each group.
state_id: [Optional] The state identifier to use.
Returns:
A MuxObservable with one observable per group.
"""
_group_by, outer_obs = group_by_mux(key_mapper)
_group_by, outer_obs = group_by_mux(key_mapper, state_id)
pipeline = rx.pipe(*pipeline) if type(pipeline) is list else pipeline
return rx.pipe(
_group_by,
Expand Down
23 changes: 20 additions & 3 deletions rxsci/state/state_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,32 @@ def __init__(self):
self.states = []
self.ids = {}

def create_mapper(self, name):
def create_mapper(self, name, state_id=None):
"""A mapper is a non-indexable state. Mappers are used in group_by
operator (where key is mapped to an index). They do not need to be
stored on persistent storage if no other states are used in the
applcation.
"""
return self.create_state(name, data_type='mapper')
return self.create_state(name, data_type='mapper', state_id=state_id)

def create_state(self, name, data_type, default_value=None, state_id=None):
if state_id is not None:
unique_name = '{}-{}'.format(name, state_id)
statedef = StateDef(unique_name, data_type, default_value)
index = 0
for s in self.states:
if s.name == unique_name:
if self.states[index] != statedef:
raise ValueError(
"Cannot share a state with different specs: {} != {}".format(
self.states[index], statedef
))
return index
index += 1

self.states.append(statedef)
return len(self.states) - 1

def create_state(self, name, data_type, default_value=None):
if name in self.ids:
self.ids[name] += 1
else:
Expand Down
75 changes: 74 additions & 1 deletion tests/operators/test_group_by.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import rx
import rx.operators as ops
from rx.subject import Subject
import rxsci as rs
from ..utils import on_probe_state_topology

Expand All @@ -18,7 +19,7 @@ def on_completed():
rx.from_(source).pipe(
rs.state.with_store(
store,
rx.pipe(
rx.pipe(
rs.ops.group_by(
lambda i: i,
rx.pipe(
Expand Down Expand Up @@ -125,3 +126,75 @@ def test_forward_topology_probe():
).subscribe()

assert len(actual_topology_probe) == 1


"""
def test_group_by_shared_store():
source1 = Subject() # [1, 2, 2, 1]
source2 = Subject() # [2, 1, 2, 1]
actual_error = []
actual_completed = []
actual_result = []
mux_actual_result1 = []
mux_actual_result2 = []
def on_completed():
actual_completed.append(True)
store = rs.state.StoreManager(store_factory=rs.state.MemoryStore)
#rx.from_(source1).pipe(
source1.pipe(
rs.state.with_store(
store,
rx.pipe(
rs.ops.group_by(
lambda i: i,
rx.pipe(
ops.do_action(mux_actual_result1.append),
),
state_id='shared_state'
),
),
),
).subscribe()
source2.pipe(
rs.state.with_store(
store,
rx.pipe(
rs.ops.group_by(
lambda i: i,
rx.pipe(
ops.do_action(mux_actual_result2.append),
),
state_id='shared_state'
),
),
),
).subscribe(
on_next=actual_result.append,
on_completed=on_completed,
on_error=actual_error.append,
)
source1.on_next(1)
source1.on_next(2)
source2.on_next(2)
source2.on_next(1)
#assert actual_error == []
#assert actual_completed == [True]
#assert actual_result == source
assert type(mux_actual_result1[0]) is rs.state.ProbeStateTopology
assert mux_actual_result1[1:] == [
rs.OnCreateMux((1 ,(0,)), store),
rs.OnNextMux((1, (0,)), 2, store),
rs.OnCreateMux((0, (0,)), store),
rs.OnNextMux((0, (0,)), 1, store),
rs.OnNextMux((1, (0,)), 2, store),
rs.OnNextMux((0, (0,)), 1, store),
rs.OnCompletedMux((0, (0,)), store),
rs.OnCompletedMux((1, (0,)), store),
]
"""

0 comments on commit 87df912

Please sign in to comment.