diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d700fb83b9b70..dbacdbff7383a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql]( * df.groupBy("year").pivot("course").sum("earnings") * }}} * + * From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by + * multiple columns, use the `struct` function to combine the columns and values: + * + * {{{ + * df.groupBy("year") + * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) + * .agg(sum($"earnings")) + * }}} + * * @param pivotColumn Name of the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 @@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql]( def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => + val valueExprs = values.map(_ match { + case c: Column => c.expr + case v => Literal.apply(v) + }) new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs)) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 3f37e5814ccaa..00f41d6484afb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -317,6 +317,22 @@ public void pivot() { Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); } + @Test + public void pivotColumnValues() { + Dataset df = spark.table("courseSales"); + List actual = df.groupBy("year") + .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java"))) + .agg(sum("earnings")).orderBy("year").collectAsList(); + + Assert.assertEquals(2012, actual.get(0).getInt(0)); + Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01); + + Assert.assertEquals(2013, actual.get(1).getInt(0)); + Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); + } + private String getResource(String resource) { try { // The following "getResource" has different behaviors in SBT and Maven. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b972b9ef93e5e..02ab19754b0c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { assert(exception.getMessage.contains("aggregate functions are not allowed")) } + + test("pivoting column list with values") { + val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training"), Seq( + struct(lit("dotnet"), lit("Experts")), + struct(lit("java"), lit("Dummies"))) + ).agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("pivoting column list") { + val exception = intercept[RuntimeException] { + trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training")) + .agg(sum($"sales.earnings")) + .collect() + } + assert(exception.getMessage.contains("Unsupported literal type")) + } }