Skip to content

Commit

Permalink
[BEAM-12100][BEAM-10379][BEAM-9514][BEAM-12647][BEAM-12099] Assertion…
Browse files Browse the repository at this point in the history
…Error type mismatch from AggregateScanConverter (apache#15174)

* [BEAM-12100] SUM throws error when overflow/underflow occurs

* [BEAM-10379] Remove filter of nulls in AggregateCombineFnAdapter and create a DropNull wrapper for the aggregations that rely on dropping nulls

* Fix checkstyle issues

* [BEAM-12647] Branch combiner in Aggregations

* [BEAM-12647] Create different paths for aggregations with and without GroupBy

* [BEAM-12647] Override identity in LongSums

* [BEAM-12647] Override identity in CombineFns, add condition to avoid undefined division

* [BEAM-12647] Fix Coder Cast Exceptions

* [BEAM-12647] Fix CountIf coder and result bug

* Fix spotless

* [BEAM-12099] Fix return value in Bit_OR for empty arrays

* [BEAM-12647] Add .withoutDefaults() to GloballyCombineFn for unbounded pcollections

* [BEAM-12647] Change condition in Group GloballyCombine to check if the input has a GlobalWindow

* Fix accesors and comments

* [BEAM-12647] Refactor aggregation combineFn to avoid code duplication

* Change interface name and fix typo

* Branch SUM and SUM0 for null management

* Change AggregateCombiner from interface to abstract class to extend PTransform

* Fix SUM0 message

* Change schemaFn initialize from null to SchemaAggregateFn.create()

* Add javadocs to AggregationCombiner

* Fix Coder error in bitAnd operator

* Fix null handling for bitOr and bitAnd ZetaSQL

* Fix null handling BitXor

* Fix null handling array_agg

* Fix @nullable checker error
  • Loading branch information
benWize authored and calvinleungyk committed Sep 22, 2021
1 parent 589ad13 commit 344a477
Show file tree
Hide file tree
Showing 14 changed files with 634 additions and 118 deletions.
Expand Up @@ -38,6 +38,7 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
Expand Down Expand Up @@ -365,18 +366,77 @@ public PCollection<OutputT> expand(PCollection<InputT> input) {
}

/**
* a {@link PTransform} that does a global combine using an aggregation built up by calls to
* a {@link PTransform} that does a combine using an aggregation built up by calls to
* aggregateField and aggregateFields. The output of this transform will have a schema that is
* determined by the output types of all the composed combiners.
*
* @param <InputT>
*/
public static class CombineFieldsGlobally<InputT>
public abstract static class AggregateCombiner<InputT>
extends PTransform<PCollection<InputT>, PCollection<Row>> {

/**
* Build up an aggregation function over the input elements.
*
* <p>This method specifies an aggregation over single field of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*/
public abstract <CombineInputT, AccumT, CombineOutputT>
AggregateCombiner<InputT> aggregateField(
int inputFieldId,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField);

/**
* Build up an aggregation function over the input elements.
*
* <p>This method specifies an aggregation over single field of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*/
public abstract <CombineInputT, AccumT, CombineOutputT>
AggregateCombiner<InputT> aggregateField(
String inputFieldName,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField);

/**
* Build up an aggregation function over the input elements by field id.
*
* <p>This method specifies an aggregation over multiple fields of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*
* <p>Field types in the output schema will be inferred from the provided combine function.
* Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
* overload that allows setting the output field type explicitly.
*/
public abstract <CombineInputT, AccumT, CombineOutputT>
AggregateCombiner<InputT> aggregateFieldsById(
List<Integer> inputFieldIds,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField);
}

/**
* a {@link PTransform} that does a global combine using an aggregation built up by calls to
* aggregateField and aggregateFields. The output of this transform will have a schema that is
* determined by the output types of all the composed combiners.
*/
public static class CombineFieldsGlobally<InputT> extends AggregateCombiner<InputT> {
private final SchemaAggregateFn.Inner schemaAggregateFn;

CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn) {
this.schemaAggregateFn = schemaAggregateFn;
}

/**
* Returns a transform that does a global combine using an aggregation built up by calls to
* aggregateField and aggregateFields. This transform will have an unknown schema that will be
* determined by the output types of all the composed combiners.
*/
public static CombineFieldsGlobally create() {
return new CombineFieldsGlobally<>(SchemaAggregateFn.create());
}

/**
* Build up an aggregation function over the input elements.
*
Expand Down Expand Up @@ -431,6 +491,7 @@ CombineFieldsGlobally<InputT> aggregateFieldBaseValue(
* <p>This method specifies an aggregation over single field of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*/
@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
String inputFieldName,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Expand All @@ -450,6 +511,7 @@ CombineFieldsGlobally<InputT> aggregateFieldBaseValue(
FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
}

@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
return new CombineFieldsGlobally<>(
Expand Down Expand Up @@ -526,6 +588,7 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> agg
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}

@Override
public <CombineInputT, AccumT, CombineOutputT>
CombineFieldsGlobally<InputT> aggregateFieldsById(
List<Integer> inputFieldIds,
Expand All @@ -551,9 +614,13 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> agg
@Override
public PCollection<Row> expand(PCollection<InputT> input) {
SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema());
Combine.Globally<Row, Row> combineFn = Combine.globally(fn);
if (!(input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows)) {
combineFn = combineFn.withoutDefaults();
}
return input
.apply("toRows", Convert.toRows())
.apply("Global Combine", Combine.globally(fn))
.apply("Global Combine", combineFn)
.setRowSchema(fn.getOutputSchema());
}
}
Expand All @@ -566,8 +633,7 @@ public PCollection<Row> expand(PCollection<InputT> input) {
* specified extracted fields.
*/
@AutoValue
public abstract static class ByFields<InputT>
extends PTransform<PCollection<InputT>, PCollection<Row>> {
public abstract static class ByFields<InputT> extends AggregateCombiner<InputT> {
abstract FieldAccessDescriptor getFieldAccessDescriptor();

abstract String getKeyField();
Expand Down Expand Up @@ -698,6 +764,7 @@ CombineFieldsByFields<InputT> aggregateFieldBaseValue(
* <p>This method specifies an aggregation over single field of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*/
@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
String inputFieldName,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Expand Down Expand Up @@ -725,6 +792,7 @@ CombineFieldsByFields<InputT> aggregateFieldBaseValue(
getValueField());
}

@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
return CombineFieldsByFields.of(
Expand Down Expand Up @@ -812,6 +880,7 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> agg
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}

@Override
public <CombineInputT, AccumT, CombineOutputT>
CombineFieldsByFields<InputT> aggregateFieldsById(
List<Integer> inputFieldIds,
Expand Down Expand Up @@ -875,8 +944,7 @@ public void process(@Element KV<Row, Iterable<Row>> e, OutputReceiver<Row> o) {
* determined by the output types of all the composed combiners.
*/
@AutoValue
public abstract static class CombineFieldsByFields<InputT>
extends PTransform<PCollection<InputT>, PCollection<Row>> {
public abstract static class CombineFieldsByFields<InputT> extends AggregateCombiner<InputT> {
abstract ByFields<InputT> getByFields();

abstract SchemaAggregateFn.Inner getSchemaAggregateFn();
Expand Down Expand Up @@ -995,6 +1063,7 @@ CombineFieldsByFields<InputT> aggregateFieldBaseValue(
* <p>This method specifies an aggregation over single field of the input. The union of all
* calls to aggregateField and aggregateFields will determine the output schema.
*/
@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
String inputFieldName,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Expand All @@ -1020,6 +1089,7 @@ CombineFieldsByFields<InputT> aggregateFieldBaseValue(
.build();
}

@Override
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
return toBuilder()
Expand Down Expand Up @@ -1095,6 +1165,7 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> agg
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}

@Override
public <CombineInputT, AccumT, CombineOutputT>
CombineFieldsByFields<InputT> aggregateFieldsById(
List<Integer> inputFieldIds,
Expand Down
Expand Up @@ -31,6 +31,7 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
Expand Down Expand Up @@ -216,6 +217,8 @@ private static class Transform extends PTransform<PCollectionList<Row>, PCollect
private WindowFn<Row, IntervalWindow> windowFn;
private int windowFieldIndex;
private List<FieldAggregation> fieldAggregations;
private final int groupSetCount;
private boolean ignoreValues;

private Transform(
WindowFn<Row, IntervalWindow> windowFn,
Expand All @@ -227,6 +230,8 @@ private Transform(
this.windowFieldIndex = windowFieldIndex;
this.fieldAggregations = fieldAggregations;
this.outputSchema = outputSchema;
this.groupSetCount = groupSet.asList().size();
this.ignoreValues = false;
this.keyFieldsIds =
groupSet.asList().stream().filter(i -> i != windowFieldIndex).collect(toList());
}
Expand All @@ -243,55 +248,68 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
if (windowFn != null) {
windowedStream = assignTimestampsAndWindow(upstream);
}

validateWindowIsSupported(windowedStream);
// Check if have fields to be grouped
if (groupSetCount > 0) {
org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> byFields =
org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(byFields);
boolean verifyRowValues =
pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class).getVerifyRowValues();
return windowedStream
.apply(combiner)
.apply(
"mergeRecord",
ParDo.of(
mergeRecord(outputSchema, windowFieldIndex, ignoreValues, verifyRowValues)))
.setRowSchema(outputSchema);
}
org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> globally =
org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsGlobally.create();
PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(globally);
return windowedStream.apply(combiner).setRowSchema(outputSchema);
}

private PTransform<PCollection<Row>, PCollection<Row>> createCombiner(
org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> initialCombiner) {

org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> byFields =
org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsByFields<Row> combined = null;
org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner combined = null;
for (FieldAggregation fieldAggregation : fieldAggregations) {
List<Integer> inputs = fieldAggregation.inputs;
CombineFn combineFn = fieldAggregation.combineFn;
if (inputs.size() > 1 || inputs.isEmpty()) {
// In this path we extract a Row (an empty row if inputs.isEmpty).
if (inputs.size() == 1) {
// Combining over a single field, so extract just that field.
combined =
(combined == null)
? byFields.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField)
: combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
? initialCombiner.aggregateField(
inputs.get(0), combineFn, fieldAggregation.outputField)
: combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
} else {
// Combining over a single field, so extract just that field.
// In this path we extract a Row (an empty row if inputs.isEmpty).
combined =
(combined == null)
? byFields.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField)
: combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
? initialCombiner.aggregateFieldsById(
inputs, combineFn, fieldAggregation.outputField)
: combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
}
}

PTransform<PCollection<Row>, PCollection<Row>> combiner = combined;
boolean ignoreValues = false;
if (combiner == null) {
// If no field aggregations were specified, we run a constant combiner that always returns
// a single empty row for each key. This is used by the SELECT DISTINCT query plan - in this
// case a group by is generated to determine unique keys, and a constant null combiner is
// used.
combiner =
byFields.aggregateField(
initialCombiner.aggregateField(
"*",
AggregationCombineFnAdapter.createConstantCombineFn(),
Field.of(
"e",
FieldType.row(AggregationCombineFnAdapter.EMPTY_SCHEMA).withNullable(true)));
ignoreValues = true;
}

boolean verifyRowValues =
pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class).getVerifyRowValues();
return windowedStream
.apply(combiner)
.apply(
"mergeRecord",
ParDo.of(mergeRecord(outputSchema, windowFieldIndex, ignoreValues, verifyRowValues)))
.setRowSchema(outputSchema);
return combiner;
}

/** Extract timestamps from the windowFieldIndex, then window into windowFns. */
Expand Down Expand Up @@ -349,7 +367,6 @@ public void processElement(
if (!ignoreValues) {
fieldValues.addAll(kvRow.getRow(1).getValues());
}

if (windowStartFieldIndex != -1) {
fieldValues.add(windowStartFieldIndex, ((IntervalWindow) window).start());
}
Expand Down

0 comments on commit 344a477

Please sign in to comment.