Skip to content

Commit

Permalink
[FLINK-5400] [core] Add accessor to folding states in RuntimeContext
Browse files Browse the repository at this point in the history
This closes apache#3053
  • Loading branch information
xiaogang.sxg authored and joseprupi committed Feb 12, 2017
1 parent 9d0c48b commit 854077e
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 10 deletions.
Expand Up @@ -27,6 +27,8 @@
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
Expand Down Expand Up @@ -307,7 +309,7 @@ public interface RuntimeContext {
<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);

/**
* Gets a handle to the system's key/value list state. This state is similar to the state
* Gets a handle to the system's key/value reducing state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* aggregates values.
*
Expand All @@ -319,16 +321,16 @@ public interface RuntimeContext {
*
* keyedStream.map(new RichMapFunction<MyType, List<MyType>>() {
*
* private ReducingState<Long> sum;
* private ReducingState<Long> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getReducingState(
* new ReducingStateDescriptor<>("sum", MyType.class, 0L, (a, b) -> a + b));
* new ReducingStateDescriptor<>("sum", (a, b) -> a + b, Long.class));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* sum.add(value.count());
* return new Tuple2<>(value, sum.get());
* state.add(value.count());
* return new Tuple2<>(value, state.get());
* }
* });
*
Expand All @@ -345,4 +347,44 @@ public interface RuntimeContext {
*/
@PublicEvolving
<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);

/**
* Gets a handle to the system's key/value folding state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* aggregates values with different types.
*
* <p>This state is only accessible if the function is executed on a KeyedStream.
*
* <pre>{@code
* DataStream<MyType> stream = ...;
* KeyedStream<MyType> keyedStream = stream.keyBy("id");
*
* keyedStream.map(new RichMapFunction<MyType, List<MyType>>() {
*
* private FoldingState<MyType, Long> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getReducingState(
* new FoldingStateDescriptor<>("sum", 0L, (a, b) -> a.count() + b, Long.class));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* state.add(value);
* return new Tuple2<>(value, state.get());
* }
* });
*
* }</pre>
*
* @param stateProperties The descriptor defining the properties of the stats.
*
* @param <T> The type of value stored in the state.
*
* @return The partitioned state object.
*
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
@PublicEvolving
<T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties);
}
Expand Up @@ -29,6 +29,8 @@
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
Expand Down Expand Up @@ -205,4 +207,11 @@ public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> statePro
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}

@Override
@PublicEvolving
public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
}
Expand Up @@ -118,7 +118,7 @@ public interface KeyedStateStore {
<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);

/**
* Gets a handle to the system's key/value list state. This state is similar to the state
* Gets a handle to the system's key/value reducing state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* aggregates values.
*
Expand All @@ -130,16 +130,16 @@ public interface KeyedStateStore {
*
* keyedStream.map(new RichMapFunction<MyType, List<MyType>>() {
*
* private ReducingState<Long> sum;
* private ReducingState<Long> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getReducingState(
* new ReducingStateDescriptor<>("sum", MyType.class, 0L, (a, b) -> a + b));
* new ReducingStateDescriptor<>("sum", (a, b) -> a + b, Long.class));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* sum.add(value.count());
* return new Tuple2<>(value, sum.get());
* state.add(value.count());
* return new Tuple2<>(value, state.get());
* }
* });
*
Expand All @@ -156,4 +156,44 @@ public interface KeyedStateStore {
*/
@PublicEvolving
<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);

/**
* Gets a handle to the system's key/value folding state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* aggregates values with different types.
*
* <p>This state is only accessible if the function is executed on a KeyedStream.
*
* <pre>{@code
* DataStream<MyType> stream = ...;
* KeyedStream<MyType> keyedStream = stream.keyBy("id");
*
* keyedStream.map(new RichMapFunction<MyType, List<MyType>>() {
*
* private FoldingState<MyType, Long> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getReducingState(
* new FoldingStateDescriptor<>("sum", 0L, (a, b) -> a.count() + b, Long.class));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* state.add(value);
* return new Tuple2<>(value, state.get());
* }
* });
*
* }</pre>
*
* @param stateProperties The descriptor defining the properties of the stats.
*
* @param <T> The type of value stored in the state.
*
* @return The partitioned state object.
*
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
@PublicEvolving
<T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties);
}
Expand Up @@ -20,6 +20,8 @@

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
Expand Down Expand Up @@ -80,6 +82,17 @@ public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> statePro
}
}

@Override
public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) {
requireNonNull(stateProperties, "The state properties must not be null");
try {
stateProperties.initializeSerializerUnlessSet(executionConfig);
return getPartitionedState(stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
}

private <S extends State> S getPartitionedState(StateDescriptor<S, ?> stateDescriptor) throws Exception {
return keyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
Expand Down
Expand Up @@ -32,6 +32,8 @@
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
Expand Down Expand Up @@ -164,6 +166,12 @@ public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> statePro
throw new UnsupportedOperationException("State is not supported in rich async functions.");
}

@Override
public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) {
throw new UnsupportedOperationException("State is not supported in rich async functions.");
}


@Override
public <V, A extends Serializable> void addAccumulator(String name, Accumulator<V, A> accumulator) {
throw new UnsupportedOperationException("Accumulators are not supported in rich async functions.");
Expand Down
Expand Up @@ -22,6 +22,8 @@
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
Expand Down Expand Up @@ -128,6 +130,13 @@ public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> statePro
return keyedStateStore.getReducingState(stateProperties);
}

@Override
public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) {
KeyedStateStore keyedStateStore = checkPreconditionsAndGetKeyedStateStore(stateProperties);
stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
return keyedStateStore.getFoldingState(stateProperties);
}

private KeyedStateStore checkPreconditionsAndGetKeyedStateStore(StateDescriptor<?, ?> stateDescriptor) {
Preconditions.checkNotNull(stateDescriptor, "The state properties must not be null");
KeyedStateStore keyedStateStore = operator.getKeyedStateStore();
Expand Down
Expand Up @@ -21,9 +21,11 @@
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
Expand Down Expand Up @@ -164,6 +166,17 @@ public Integer reduce(Integer value1, Integer value2) throws Exception {
// expected
}

try {
runtimeContext.getFoldingState(new FoldingStateDescriptor<>("foobar", 0, new FoldFunction<Integer, Integer>() {
@Override
public Integer fold(Integer accumulator, Integer value) throws Exception {
return accumulator;
}
}, Integer.class));
} catch (UnsupportedOperationException e) {
// expected
}

try {
runtimeContext.addAccumulator("foobar", new Accumulator<Integer, Integer>() {
private static final long serialVersionUID = -4673320336846482358L;
Expand Down
Expand Up @@ -22,7 +22,9 @@
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
Expand Down Expand Up @@ -115,6 +117,35 @@ public void testReducingStateInstantiation() throws Exception {
assertTrue(((KryoSerializer<?>) serializer).getKryo().getRegistration(Path.class).getId() > 0);
}

@Test
public void testFoldingStateInstantiation() throws Exception {

final ExecutionConfig config = new ExecutionConfig();
config.registerKryoType(Path.class);

final AtomicReference<Object> descriptorCapture = new AtomicReference<>();

StreamingRuntimeContext context = new StreamingRuntimeContext(
createDescriptorCapturingMockOp(descriptorCapture, config),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());

@SuppressWarnings("unchecked")
FoldFunction<String, TaskInfo> folder = (FoldFunction<String, TaskInfo>) mock(FoldFunction.class);

FoldingStateDescriptor<String, TaskInfo> descr =
new FoldingStateDescriptor<>("name", null, folder, TaskInfo.class);

context.getFoldingState(descr);

FoldingStateDescriptor<?, ?> descrIntercepted = (FoldingStateDescriptor<?, ?>) descriptorCapture.get();
TypeSerializer<?> serializer = descrIntercepted.getSerializer();

// check that the Path class is really registered, i.e., the execution config was applied
assertTrue(serializer instanceof KryoSerializer);
assertTrue(((KryoSerializer<?>) serializer).getKryo().getRegistration(Path.class).getId() > 0);
}

@Test
public void testListStateInstantiation() throws Exception {

Expand Down

0 comments on commit 854077e

Please sign in to comment.