[FLINK-4563] [metrics] scope caching not adjusted for multiple reporters
Add tests; attempt a filtering rule in org.apache.flink.api.table.plan.rules.dataSet.DataSetAggregateRule.scala:39.
Anton Mushin committed Oct 13, 2016
1 parent 8f802db commit 470f19d
Showing 9 changed files with 297 additions and 29 deletions.
47 changes: 45 additions & 2 deletions docs/dev/table_api.md
@@ -973,7 +973,7 @@ dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOLEAN" |

as = composite , ".as(" , fieldReference , ")" ;

aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" ) , [ "()" ] ;
aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" | ".stddev_pop" | ".stddev_samp" | ".var_pop" | ".var_samp" ) , [ "()" ] ;

if = composite , ".?(" , expression , "," , expression , ")" ;

@@ -3070,9 +3070,52 @@ MIN(value)
<p>Returns the minimum value of <i>value</i> across all input values.</p>
</td>
</tr>

<tr>
<td>
{% highlight text %}
STDDEV_POP(value)
{% endhighlight %}
</td>
<td>
<p>Returns the population standard deviation of numeric <i>value</i> across all input values.</p>
</td>
</tr>

<tr>
<td>
{% highlight text %}
STDDEV_SAMP(value)
{% endhighlight %}
</td>
<td>
<p>Returns the sample standard deviation of numeric <i>value</i> across all input values.</p>
</td>
</tr>

<tr>
<td>
{% highlight text %}
VAR_POP(value)
{% endhighlight %}
</td>
<td>
<p>Returns the population variance of numeric <i>value</i> across all input values.</p>
</td>
</tr>

<tr>
<td>
{% highlight text %}
VAR_SAMP(value)
{% endhighlight %}
</td>
<td>
<p>Returns the sample variance of numeric <i>value</i> across all input values.</p>
</td>
</tr>
</tbody>
</table>

</div>
</div>
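
A minimal usage sketch of the new functions (not part of this commit; it assumes a
table `MyTable` with a numeric column `a` registered on a `TableEnvironment` named
`tableEnv`):

{% highlight scala %}
// Sketch: query the new deviation aggregates through the SQL API.
val result = tableEnv.sql(
  "SELECT STDDEV_POP(a), STDDEV_SAMP(a), VAR_POP(a), VAR_SAMP(a) FROM MyTable")
{% endhighlight %}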

@@ -268,7 +268,7 @@ abstract class BatchTableEnvironment(
case a: AssertionError =>
throw a.getCause
}
print(s"\n${RelOptUtil.toString(dataSetPlan)}\n${RelOptUtil.toString(relNode)}")
print(s"${RelOptUtil.toString(dataSetPlan)}\n")
dataSetPlan match {
case node: DataSetRel =>
node.translateToPlan(
@@ -277,8 +277,6 @@
).asInstanceOf[DataSet[A]]
case _ => ???
}
// val c:DataSet[A] = ???
// c
}

}
@@ -23,6 +23,7 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.sql.SqlKind
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.runtime.aggregate.AggregateUtil
@@ -79,7 +80,21 @@ class DataSetAggregate(
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
val aggCnt = this.namedAggregates.size
planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize)
var resultCost = planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize)
// this.namedAggregates.foreach(x => {
// x.getKey.getAggregation.getKind match {
// case SqlKind.STDDEV_POP =>
// resultCost = resultCost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize))
// case SqlKind.STDDEV_SAMP =>
// resultCost = resultCost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize))
// case SqlKind.VAR_SAMP =>
// resultCost = resultCost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize))
// case SqlKind.VAR_POP =>
// resultCost = resultCost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize))
// case default => None
// }
// })
resultCost
}

override def translateToPlan(
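
The commented-out block above hints at charging extra for deviation-style aggregates.
If that idea were kept, a more compact formulation might look like the following
sketch (hypothetical, not part of the commit; rowCnt, aggCnt, rowSize and
namedAggregates are the values already in scope in computeSelfCost):

// Hypothetical: add one extra base cost per deviation-style aggregate, since
// STDDEV/VAR expand into several simple aggregates plus post-processing.
val deviationKinds = Set(
  SqlKind.STDDEV_POP, SqlKind.STDDEV_SAMP, SqlKind.VAR_POP, SqlKind.VAR_SAMP)
val extraAggs = namedAggregates.count(
  agg => deviationKinds.contains(agg.getKey.getAggregation.getKind))
val baseCost = planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize)
// RelOptCost.plus accumulates cost, so fold one extra base cost per deviation aggregate.
(0 until extraAggs).foldLeft(baseCost)((cost, _) => cost.plus(baseCost))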
@@ -106,8 +106,7 @@ object FlinkRuleSets {
DataSetMinusRule.INSTANCE,
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
BatchTableSourceScanRule.INSTANCE,
AggregateReduceFunctionsRule.INSTANCE
BatchTableSourceScanRule.INSTANCE
)

/**
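
For reference, Calcite's AggregateReduceFunctionsRule (removed from the batch rule set
here, though imported by DataSetAggregateRule below) rewrites deviation aggregates into
sums and counts via the standard identities, which are also the expansions the tests in
this commit use to build their expected results (the sample variants guard the
COUNT(x) = 1 case with NULL):

VAR_POP(x)     = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
VAR_SAMP(x)    = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / (COUNT(x) - 1)
STDDEV_POP(x)  = SQRT(VAR_POP(x))
STDDEV_SAMP(x) = SQRT(VAR_SAMP(x))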
@@ -22,6 +22,8 @@ import org.apache.calcite.plan.{RelOptRuleCall, Convention, RelOptRule, RelTrait
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.logical.LogicalAggregate
import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule
import org.apache.calcite.sql.SqlKind
import org.apache.flink.api.table.TableException
import org.apache.flink.api.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
import scala.collection.JavaConversions._
@@ -49,7 +51,14 @@ class DataSetAggregateRule
throw new TableException("GROUPING SETS are currently not supported.")
}

!distinctAggs && !groupSets && !agg.indicator
val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall {
case SqlKind.SUM => true
case SqlKind.MIN => true
case SqlKind.MAX => true
case _ => false
}

!distinctAggs && !groupSets && !agg.indicator && supported
}

override def convert(rel: RelNode): RelNode = {
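
The guard above enumerates the aggregate kinds the DataSet runtime implements directly.
An equivalent formulation against a set keeps that list in one place (a sketch of a
possible refactor, not part of the commit):

// Hypothetical: the same predicate as the pattern match in matches() above.
val directlySupported = Set(SqlKind.SUM, SqlKind.MIN, SqlKind.MAX)
val supported = agg.getAggCallList.forall(
  call => directlySupported.contains(call.getAggregation.getKind))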
@@ -18,6 +18,7 @@

package org.apache.flink.api.java.batch.sql;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.table.BatchTableEnvironment;
@@ -31,7 +32,11 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.mutable.Buffer;

import java.util.Arrays;
import java.util.List;

@RunWith(Parameterized.class)
@@ -118,4 +123,48 @@ public void testJoin() throws Exception {
String expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n";
compareResultAsText(results, expected);
}

@Test
public void testDeviationAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());

DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
tableEnv.registerDataSet("AggTable", ds, "x, y, z");

Buffer<String> columnForAgg = JavaConversions.asScalaBuffer(Arrays.asList("x, y".split(",")));

String sqlQuery = getSelectQuery("AVG(?),STDDEV_POP(?),STDDEV_SAMP(?),VAR_POP(?),VAR_SAMP(?)", columnForAgg, "AggTable");
Table result = tableEnv.sql(sqlQuery);

String sqlQuery1 = getSelectQuery("SUM(?)/COUNT(?), " +
"SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?)), " +
"SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END), " +
"(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?), " +
"(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END", columnForAgg, "AggTable");

Table result1 = tableEnv.sql(sqlQuery1);

DataSet<Row> resultSet = tableEnv.toDataSet(result, Row.class);
List<Row> results = resultSet.collect();

DataSet<Row> expectedResultSet = tableEnv.toDataSet(result1, Row.class);
String expectedResults = expectedResultSet.map(new MapFunction<Row, Object>() {
@Override
public Object map(Row value) throws Exception {
StringBuilder builder = new StringBuilder();
Iterator<Object> productIterator = value.productIterator();
while (productIterator.hasNext()) {
Object product = productIterator.next();
builder.append(Double.valueOf(product.toString()).intValue()).append(",");
}
return builder.substring(0, builder.length() - 1);
}
}).collect().toString().replaceAll("\\[|\\]", "");

compareResultAsText(results, expectedResults);
}
}
@@ -19,6 +19,7 @@
package org.apache.flink.api.scala.batch.sql

import java.util
import java.util.Comparator

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
@@ -30,6 +31,7 @@ import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.TestBaseUtils
import org.apache.log4j.AppenderSkeleton
import org.apache.log4j.spi.LoggingEvent
import org.junit.Assert._
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@@ -265,36 +267,109 @@ class AggregationsITCase(

@Test
def testStddevPopAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val ds = env.fromElements(
(1: Byte, 1 : Short, 1, 1L, 1F, 1D),
(2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv)
tEnv.registerTable("myTable", ds)
val columns = Array("_1","_2","_3","_4","_5","_6")

val sqlQuery = getSelectQuery("STDDEV_POP(?)")(columns,"myTable")
// val sqlExpectedQuery = getSelectQuery("SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?))")(columns,"myTable")

val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect()
//val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect().toString.replaceAll("Buffer\\(|\\)", "")
val expectedResult = "0,0,0,0,0.5,0.5"
TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult)
}
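
// Worked check of the hardcoded expectation above (added commentary, not in the
// commit): every column holds the values {1, 2}, so SUM(x) = 3, SUM(x * x) = 5 and
// COUNT(x) = 2. In exact arithmetic VAR_POP = (5 - 3 * 3 / 2) / 2 = 0.25, hence
// STDDEV_POP = 0.5 for the Float and Double columns. The four integer-typed columns
// evaluate the same expression in integer arithmetic (3 * 3 / 2 = 4, (5 - 4) / 2 = 0),
// which matches the four leading zeros in "0,0,0,0,0.5,0.5".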

@Test
def testStddevPopAggregateWithOtherAggregate(): Unit = {
val localconf = config
localconf.setNullCheck(true)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, localconf)

val ds = env.fromElements(
(1: Byte, 1 : Short, 1, 1L, 1F, 1D),
(2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv)
tEnv.registerTable("myTable", ds)
val columns = Array("_6")

val sqlQuery = getSelectQuery("STDDEV_POP(?), sum(?)")(columns,"myTable")

// val sqlExpectedQuery = getSelectQuery("SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?))," +
// "avg(?),sum(?),max(?),min(?),count(?)")(columns,"myTable")

val sqlExpectedQuery = getSelectQuery("SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?))," +
"sum(?),avg(?)")(columns,"myTable")

val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect().toString.replaceAll("Buffer\\(|\\)", "")
//val expectedResult = "0.0,1,3,2,1,2,0.0,1,3,2,1,2,0.0,1,3,2,1,2,0.0,1,3,2,1,2,0.5,1.5,3.0,2.0,1.0,2,0.5,1.5,3.0,2.0,1.0,2"
TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult)
}

@Test
def testStddevSampAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

// val sqlQuery = "SELECT STDDEV_POP(_1), STDDEV_POP(_2), STDDEV_POP(_3), STDDEV_POP(_4), STDDEV_POP(_5), " +
// "STDDEV_POP(_6),STDDEV_SAMP(_6),VAR_POP(_6)" +
// "FROM MyTable"
val sqlQuery ="SELECT stddev_pop(_1),STDDEV_SAMP(_1),VAR_SAMP(_1),VAR_POP(_1) FROM MyTable"
val ds1 = env.fromElements(
(1: Byte, 1 : Short, 1, 1L, 1F, 1D),
(2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv)
tEnv.registerTable("myTable", ds1)
val columns = Seq("_1","_2","_3","_4","_5","_6")

// val sqlQuery2 = "SELECT " +
// "SQRT((SUM(a*a) - SUM(a)*SUM(a)/COUNT(a))/COUNT(a)) "+
// "from (select _1 as a from MyTable)"
val sqlQuery = getSelectQuery("STDDEV_SAMP(?)")(columns,"myTable")
//val sqlExpectedQuery = getSelectQuery("SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END)")(columns,"myTable")

val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect()
//val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect().toString.replaceAll("Buffer\\(|\\)", "")
//TODO
val expectedResult = "1,1,1,1,0.70710677,0.7071067811865476"
TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult)
}

@Test
def testVarPopAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val ds = env.fromElements(
(1: Byte, 1 : Short, 1, 1L, 1.0f, 1.0d),
(2: Byte, 2 : Short, 2, 2L, 2.0f, 2.0d)).toTable(tEnv)
tEnv.registerTable("MyTable", ds)
(1: Byte, 1 : Short, 1, 1L, 1F, 1D),
(2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv)
tEnv.registerTable("myTable", ds)
val columns = Seq("_1","_2","_3","_4","_5","_6")

val values = List(1.0,2.0)
val expectedVal = Math.sqrt((values.map(x => x * x).sum - values.sum * values.sum / values.size) / values.size)
print(s"expected STDDEV_POP = $expectedVal")
val sqlQuery = getSelectQuery("var_pop(?)")(columns,"myTable")
val sqlExpectedQuery = getSelectQuery("(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?)")(columns,"myTable")

// tEnv.sql(sqlQuery2).toDataSet[Row].collect().foreach(v=>print(v+" "))
val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect()
val table: Table = tEnv.sql(sqlQuery)
print("\n====\n"+tEnv.explain(table)+"\n====\n")
val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect().toString.replaceAll("Buffer\\(|\\)", "")
TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult)
}

@Test
def testVarSampAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val ds = env.fromElements(
(1: Byte, 1 : Short, 1, 1L, 1F, 1D),
(2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv)
tEnv.registerTable("myTable", ds)
val columns = Seq("_1","_2","_3","_4","_5","_6")

val sqlQuery = getSelectQuery("var_samp(?)")(columns,"myTable")
val sqlExpectedQuery = getSelectQuery("(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END")(columns,"myTable")

val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect().toString.replaceAll("Buffer\\(|\\)", "")
TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult)

val expectedResult = s"${Math.round(expectedVal)},${Math.round(expectedVal)},${Math.round(expectedVal)},${Math.round(expectedVal)},$expectedVal,$expectedVal"
TestBaseUtils.compareResultAsText(actualResult.asJava, expectedResult )
}

}
