diff --git a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java index f92eceea8..ae86a30eb 100644 --- a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java +++ b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java @@ -3,10 +3,14 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.LinkedList; import java.util.Map; +import java.util.stream.Collectors; + import tech.tablesaw.api.CategoricalColumn; import tech.tablesaw.api.DoubleColumn; import tech.tablesaw.api.NumericColumn; +import tech.tablesaw.columns.Column; import tech.tablesaw.api.Table; import tech.tablesaw.table.TableSlice; import tech.tablesaw.table.TableSliceGroup; @@ -24,78 +28,145 @@ public class PivotTable { /** * Returns a table that is a rotation of the given table pivoted around the key columns, and * filling the output cells using the values calculated by the {@code aggregateFunction} when - * applied to the {@code values column} grouping by the key columns + * applied to the {@code aggregatedColumn} grouping by the key columns + * + * Handles the case whereby there is a single groupingColumn and aggregatedColumn * * @param table The table that provides the data to be pivoted - * @param column1 A "key" categorical column from which the primary grouping is created. There + * @param groupingColumn A "key" categorical column from which the primary grouping is created. 
There * will be one on each row of the result - * @param column2 A second categorical column for which a subtotal is created; this produces n + * @param pivotColumn A second categorical column for which a subtotal is created; this produces n * columns on each row of the result - * @param values A numeric column that provides the values to be summarized + * @param aggregatedColumn A numeric column that provides the values to be summarized * @param aggregateFunction function that defines what operation is performed on the values in the * subgroups * @return A new, pivoted table */ + + public static Table pivot( + Table table, + CategoricalColumn groupingColumn, + CategoricalColumn pivotColumn, + NumericColumn aggregatedColumn, + AggregateFunction aggregateFunction) { + return pivot(table, List.of(groupingColumn), pivotColumn, List.of(aggregatedColumn), aggregateFunction); + } + + /** + * Returns a table that is a rotation of the given table pivoted around the key columns, and + * filling the output cells using the values calculated by the {@code aggregateFunction} when + * applied to the {@code aggregatedColumns} grouping by the key columns + * + * Handles the case whereby there may be multiple groupingColumns and/or multiple aggregatedColumns + * @param table The table that provides the data to be pivoted + * @param groupingColumns The "key" categorical columns; each one becomes a leading column of the result + * @param pivotColumn The categorical column used to derive the names of the value columns + * @param aggregatedColumns The numeric columns to summarize; when more than one is given, value columns are named "{pivotValue}.{aggregatedColumnName}" + * @param aggregateFunction function that defines what operation is performed on the values in the subgroups + * @return A new, pivoted table + */ public static Table pivot( Table table, - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn values, + List<CategoricalColumn<?>> groupingColumns, + CategoricalColumn pivotColumn, + List<NumericColumn<?>> aggregatedColumns, AggregateFunction aggregateFunction) { - TableSliceGroup tsg = table.splitOn(column1); + boolean multiAggregated = aggregatedColumns.size() > 1; - Table pivotTable = Table.create("Pivot: " + column1.name() + " x " + column2.name()); - pivotTable.addColumns(column1.type().create(column1.name())); + TableSliceGroup tsg = table.splitOn(groupingColumns.toArray(CategoricalColumn[]::new)); 
- List valueColumnNames = getValueColumnNames(table, column2); + List groupingColumnNames = groupingColumns.stream().map(_c -> _c.name()).collect(Collectors.toList()); - for (String colName : valueColumnNames) { - pivotTable.addColumns(DoubleColumn.create(colName)); - } + Table pivotTable = Table.create("Pivot: " + String.join(",", groupingColumnNames) + " x " + pivotColumn.name()); + + pivotTable.addColumns(groupingColumns.stream().map(_c -> _c.type().create(_c.name())).toArray(Column[]::new)); - int valueIndex = table.columnIndex(column2); - int keyIndex = table.columnIndex(column1); + List valueColumnNames = getValueColumnNames(table, pivotColumn); - String key; + if(multiAggregated){ + for (String colName : valueColumnNames) + for(NumericColumn aggColumn : aggregatedColumns) { + pivotTable.addColumns(DoubleColumn.create(colName + "." + aggColumn.name())); + } + } + else{ + for (String colName : valueColumnNames) { + pivotTable.addColumns(DoubleColumn.create(colName)); + } + } for (TableSlice slice : tsg.getSlices()) { - key = String.valueOf(slice.get(0, keyIndex)); - pivotTable.column(0).appendCell(key); + + for (int i = 0; i < groupingColumns.size(); i++) { + String key = String.valueOf(slice.get(0, table.columnIndex(groupingColumns.get(i)))); + pivotTable.column(i).appendCell(key); + } Map valueMap = - getValueMap(column1, column2, values, valueIndex, slice, aggregateFunction); + getValueMap(groupingColumns, pivotColumn, aggregatedColumns, slice, aggregateFunction); for (String columnName : valueColumnNames) { - Double aDouble = valueMap.get(columnName); - NumericColumn pivotValueColumn = pivotTable.numberColumn(columnName); - if (aDouble == null) { - pivotValueColumn.appendMissing(); - } else { - pivotValueColumn.appendObj(aDouble); - } - } + for (NumericColumn aggregatedColumn: aggregatedColumns) { + + String appendedColumnName; + + if(multiAggregated){ + appendedColumnName = columnName + "." 
+ + aggregatedColumn.name(); + } else { + appendedColumnName = columnName; + } + + NumericColumn pivotValueColumn = pivotTable.numberColumn(appendedColumnName); + + Double aDouble = valueMap.get(appendedColumnName); + + if (aDouble == null) { + pivotValueColumn.appendMissing(); + } else { + pivotValueColumn.appendObj(aDouble); + } + } + } + } return pivotTable; } private static Map getValueMap( - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn values, - int valueIndex, + List<CategoricalColumn<?>> groupingColumns, + CategoricalColumn pivotColumn, + List<NumericColumn<?>> aggregatedColumns, TableSlice slice, AggregateFunction function) { + boolean multiAggregated = aggregatedColumns.size() > 1; Table temp = slice.asTable(); - Table summary = temp.summarize(values.name(), function).by(column1.name(), column2.name()); + List<CategoricalColumn<?>> allKeyColumns = new ArrayList<>(groupingColumns); + allKeyColumns.add(pivotColumn); + + List aggregatedColumnNames = aggregatedColumns.stream().map(NumericColumn::name).collect(Collectors.toList()); + + Table summary = temp.summarize(aggregatedColumnNames, function).by(allKeyColumns.stream().map(CategoricalColumn::name).toArray(String[]::new)); Map valueMap = new HashMap<>(); - NumericColumn nc = summary.numberColumn(summary.columnCount() - 1); - for (int i = 0; i < summary.rowCount(); i++) { - valueMap.put(String.valueOf(summary.get(i, 1)), nc.getDouble(i)); + + if(multiAggregated){ + for (int i = 0; i < summary.rowCount(); i++) { + for (int k = 0; k < aggregatedColumns.size(); k++) { + NumericColumn nc = summary.numberColumn(groupingColumns.size() + k + 1); + valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())) + "." 
+ + aggregatedColumns.get(k).name(), nc.getDouble(i)); + } + } } + else{ + NumericColumn nc = summary.numberColumn(summary.columnCount() - 1); + for (int i = 0; i < summary.rowCount(); i++) { + valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())), nc.getDouble(i)); + } + } + return valueMap; } diff --git a/core/src/main/java/tech/tablesaw/api/Table.java b/core/src/main/java/tech/tablesaw/api/Table.java index f099adc09..d9ab35ecf 100644 --- a/core/src/main/java/tech/tablesaw/api/Table.java +++ b/core/src/main/java/tech/tablesaw/api/Table.java @@ -864,35 +864,87 @@ public Table dropWhere(Selection selection) { return newTable; } + /** + * Returns a new table, where the first n columns are the groupingColumns. There are then p additional + * columns, which is the product of each unique value in the pivot column and aggregatedColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in each of aggregatedColumn, grouped by the values of groupingColumn and pivotColumn. + * + * If more than one aggregatedColumn is provided then each is appended to each unique value of the pivot + * column in the format "{PivotColumnValue}.{AggregatedColumnName}" + * + * @param groupingColumns the categorical columns that form the leading columns of the result + * @param pivotColumn the categorical column used to derive the names of the value columns + * @param aggregatedColumns the numeric columns to which the aggregate function is applied + * @param aggregateFunction the function applied to the aggregated columns within each group + * @return a new, pivoted table + */ + public Table pivot( + List<CategoricalColumn<?>> groupingColumns, + CategoricalColumn pivotColumn, + List<NumericColumn<?>> aggregatedColumns, + AggregateFunction aggregateFunction) { + return PivotTable.pivot(this, groupingColumns, pivotColumn, aggregatedColumns, aggregateFunction); + } + + /** + * Returns a new table, where the first n columns are the groupingColumns. There are then p additional + * columns, which is the product of each unique value in the pivot column and aggregatedColumn. 
The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in each of aggregatedColumn, grouped by the values of groupingColumn and pivotColumn. + * + * If more than one aggregatedColumn is provided then each is appended to each unique value of the pivot + * column in the format "{PivotColumnValue}.{AggregatedColumnName}" + * + * @param groupingColumnNames names of the categorical columns to group by; each is looked up on this table + * @param pivotColumnName name of the categorical column to pivot on + * @param aggregatedColumnNames names of the numeric columns to aggregate; each is looked up on this table + * @param aggregateFunction the function applied to the aggregated columns within each group + * @return a new, pivoted table + */ + public Table pivot( + List groupingColumnNames, + String pivotColumnName, + List aggregatedColumnNames, + AggregateFunction aggregateFunction) { + return pivot( + groupingColumnNames.stream().map(this::categoricalColumn).collect(Collectors.toList()), + categoricalColumn(pivotColumnName), + aggregatedColumnNames.stream().map(this::numberColumn).collect(Collectors.toList()), + aggregateFunction); + } + + /** * Returns a pivot on this table, where: The first column contains unique values from the index - * column1 There are n additional columns, one for each unique value in column2 The values in each - * of the cells in these new columns are the result of applying the given AggregateFunction to the - * data in column3, grouped by the values of column1 and column2 + * groupingColumn There are n additional columns, one for each unique value in the pivotColumn. 
The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn */ public Table pivot( - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn column3, + CategoricalColumn groupingColumn, + CategoricalColumn pivotColumn, + NumericColumn aggregatedColumn, AggregateFunction aggregateFunction) { - return PivotTable.pivot(this, column1, column2, column3, aggregateFunction); + return PivotTable.pivot(this, groupingColumn, pivotColumn, aggregatedColumn, aggregateFunction); } + /** * Returns a pivot on this table, where: The first column contains unique values from the index - * column1 There are n additional columns, one for each unique value in column2 The values in each - * of the cells in these new columns are the result of applying the given AggregateFunction to the - * data in column3, grouped by the values of column1 and column2 + * groupingColumn. There are n additional columns, one for each unique value in the pivotColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn. Columns are looked up by name on this table. */ public Table pivot( - String column1Name, - String column2Name, - String column3Name, + String groupingColumnName, + String pivotColumnName, + String aggregatedColumnName, AggregateFunction aggregateFunction) { return pivot( - categoricalColumn(column1Name), - categoricalColumn(column2Name), - numberColumn(column3Name), + categoricalColumn(groupingColumnName), + categoricalColumn(pivotColumnName), + numberColumn(aggregatedColumnName), aggregateFunction); } diff --git a/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java b/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java index c6be7e3e4..037d2676e 100644 --- 
a/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java +++ b/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java @@ -6,11 +6,16 @@ import org.junit.jupiter.api.Test; import tech.tablesaw.api.Table; import tech.tablesaw.io.csv.CsvReadOptions; +import java.util.List; public class PivotTableTest { + /** + * Illustrate usage of pivot function with a single grouping, pivot and aggregated columns + * @throws Exception + */ @Test - public void pivot() throws Exception { + public void pivotSingle() throws Exception { Table t = Table.read() .csv(CsvReadOptions.builder("../data/bush.csv").missingValueIndicator(":").build()); @@ -30,4 +35,72 @@ public void pivot() throws Exception { assertTrue(pivot.columnNames().contains("2004")); assertEquals(6, pivot.rowCount()); } + + + @Test + public void pivotMultipleGroupAndAggregate() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("Team","League"), + "Year", + List.of("RS","RA","W"), + AggregateFunctions.mean); + + assertTrue(pivot.columnNames().contains("Team")); + assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001.RS")); + assertTrue(pivot.columnNames().contains("2001.RA")); + assertTrue(pivot.columnNames().contains("2001.W")); + assertEquals(143, pivot.columnCount()); + assertEquals(40, pivot.rowCount()); + } + + @Test + public void pivotMultipleGroup() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("Team","League"), + "Year", + List.of("RS"), + AggregateFunctions.mean); + + assertTrue(pivot.columnNames().contains("Team")); + assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001")); + assertTrue(pivot.columnNames().contains("2002")); + assertTrue(pivot.columnNames().contains("2003")); + assertEquals(49, 
pivot.columnCount()); + assertEquals(40, pivot.rowCount()); + } + + @Test + public void pivotMultipleAggregate() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("League"), + "Year", + List.of("RS","RA","W"), + AggregateFunctions.mean); + + assertTrue(!pivot.columnNames().contains("Team")); + assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001.RS")); + assertTrue(pivot.columnNames().contains("2001.RA")); + assertTrue(pivot.columnNames().contains("2001.W")); + assertEquals(142, pivot.columnCount()); + assertEquals(2, pivot.rowCount()); + } + }