Skip to content

Commit

Permalink
[SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

In this PR, I propose to extend the implementation of the existing method:
```
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset
```
to support values of the struct type. This allows pivoting by multiple columns combined by `struct`:
```
trainingSales
      .groupBy($"sales.year")
      .pivot(
        pivotColumn = struct(lower($"sales.course"), $"training"),
        values = Seq(
          struct(lit("dotnet"), lit("Experts")),
          struct(lit("java"), lit("Dummies")))
      ).agg(sum($"sales.earnings"))
```

## How was this patch tested?

Added tests, in both Java and Scala, for pivot values specified via `struct`.

Closes apache#22316 from MaxGekk/pivoting-by-multiple-columns2.

Lead-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Co-authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
  • Loading branch information
2 people authored and jackylee-ch committed Feb 18, 2019
1 parent 2f04982 commit 04cbe0a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by
* multiple columns, use the `struct` function to combine the columns and values:
*
* {{{
* df.groupBy("year")
* .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
* .agg(sum($"earnings"))
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
Expand Down Expand Up @@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql](
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
groupType match {
case RelationalGroupedDataset.GroupByType =>
val valueExprs = values.map(_ match {
case c: Column => c.expr
case v => Literal.apply(v)
})
new RelationalGroupedDataset(
df,
groupingExprs,
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs))
case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
Expand Down Expand Up @@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset {
/**
* To indicate it's the PIVOT
*/
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType
}
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,22 @@ public void pivot() {
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}

@Test
public void pivotColumnValues() {
  // Pivot using Column-typed values (lit(...)) rather than plain strings.
  Dataset<Row> courseSales = spark.table("courseSales");
  List<Row> rows = courseSales
    .groupBy("year")
    .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
    .agg(sum("earnings"))
    .orderBy("year")
    .collectAsList();

  // Row 0: year 2012 with one earnings column per pivot value.
  Row year2012 = rows.get(0);
  Assert.assertEquals(2012, year2012.getInt(0));
  Assert.assertEquals(15000.0, year2012.getDouble(1), 0.01);
  Assert.assertEquals(20000.0, year2012.getDouble(2), 0.01);

  // Row 1: year 2013.
  Row year2013 = rows.get(1);
  Assert.assertEquals(2013, year2013.getInt(0));
  Assert.assertEquals(48000.0, year2013.getDouble(1), 0.01);
  Assert.assertEquals(30000.0, year2013.getDouble(2), 0.01);
}

private String getResource(String resource) {
try {
// The following "getResource" has different behaviors in SBT and Maven.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

assert(exception.getMessage.contains("aggregate functions are not allowed"))
}

test("pivoting column list with values") {
  // Pivot on a struct combining two columns, supplying the distinct
  // values explicitly as struct literals.
  val pivoted = trainingSales
    .groupBy($"sales.year")
    .pivot(
      struct(lower($"sales.course"), $"training"),
      Seq(
        struct(lit("dotnet"), lit("Experts")),
        struct(lit("java"), lit("Dummies"))))
    .agg(sum($"sales.earnings"))

  // One earnings column per requested struct value; 2012 has no
  // (java, Dummies) sales, hence the null.
  checkAnswer(pivoted, Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil)
}

test("pivoting column list") {
  // Pivoting on a struct column without an explicit value list is
  // expected to fail at collect() time with an unsupported-literal error.
  val err = intercept[RuntimeException] {
    trainingSales
      .groupBy($"sales.year")
      .pivot(struct(lower($"sales.course"), $"training"))
      .agg(sum($"sales.earnings"))
      .collect()
  }
  assert(err.getMessage.contains("Unsupported literal type"))
}
}

0 comments on commit 04cbe0a

Please sign in to comment.