Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for multiple grouped / aggregated columns in pivoting #1145

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
143 changes: 107 additions & 36 deletions core/src/main/java/tech/tablesaw/aggregate/PivotTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.LinkedList;
import java.util.Map;
import java.util.stream.Collectors;

import tech.tablesaw.api.CategoricalColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.columns.Column;
import tech.tablesaw.api.Table;
import tech.tablesaw.table.TableSlice;
import tech.tablesaw.table.TableSliceGroup;
Expand All @@ -24,78 +28,145 @@ public class PivotTable {
/**
* Returns a table that is a rotation of the given table pivoted around the key columns, and
* filling the output cells using the values calculated by the {@code aggregateFunction} when
* applied to the {@code values column} grouping by the key columns
* applied to the {@code aggregatedColumn} grouping by the key columns
*
* Handles the case whereby there is a single groupingColumn and aggregatedColumn
*
* @param table The table that provides the data to be pivoted
* @param column1 A "key" categorical column from which the primary grouping is created. There
* @param groupingColumn A "key" categorical column from which the primary grouping is created. There
* will be one on each row of the result
* @param column2 A second categorical column for which a subtotal is created; this produces n
* @param pivotColumn A second categorical column for which a subtotal is created; this produces n
* columns on each row of the result
* @param values A numeric column that provides the values to be summarized
* @param aggregatedColumn A numeric column that provides the values to be summarized
* @param aggregateFunction function that defines what operation is performed on the values in the
* subgroups
* @return A new, pivoted table
*/

public static Table pivot(
Table table,
CategoricalColumn<?> groupingColumn,
CategoricalColumn<?> pivotColumn,
NumericColumn<?> aggregatedColumns,
AggregateFunction<?, ?> aggregateFunction) {
return pivot(table, List.of(groupingColumn), pivotColumn, List.of(aggregatedColumns), aggregateFunction);
}

/**
* Returns a table that is a rotation of the given table pivoted around the key columns, and
* filling the output cells using the values calculated by the {@code aggregateFunction} when
* applied to the {@code aggregatedColumns} grouping by the key columns
*
* Handles the case whereby there may be multiple groupingColumns and/or multiple aggregatedColumns
* @param table
* @param groupingColumn
* @param pivotColumn
* @param aggregatedColumns
* @param aggregateFunction
* @return
*/
public static Table pivot(
Table table,
CategoricalColumn<?> column1,
CategoricalColumn<?> column2,
NumericColumn<?> values,
List<CategoricalColumn<?>> groupingColumns,
CategoricalColumn<?> pivotColumn,
List<NumericColumn<?>> aggregatedColumns,
AggregateFunction<?, ?> aggregateFunction) {

TableSliceGroup tsg = table.splitOn(column1);
boolean multiAggregated = aggregatedColumns.size() > 1;

Table pivotTable = Table.create("Pivot: " + column1.name() + " x " + column2.name());
pivotTable.addColumns(column1.type().create(column1.name()));
TableSliceGroup tsg = table.splitOn(groupingColumns.toArray(CategoricalColumn[]::new));

List<String> valueColumnNames = getValueColumnNames(table, column2);
List<String> groupingColumnNames = groupingColumns.stream().map(_c -> _c.name()).collect(Collectors.toList());

for (String colName : valueColumnNames) {
pivotTable.addColumns(DoubleColumn.create(colName));
}
Table pivotTable = Table.create("Pivot: " + String.join(",", groupingColumnNames) + " x " + pivotColumn.name());

pivotTable.addColumns(groupingColumns.stream().map(_c -> _c.type().create(_c.name())).toArray(Column[]::new));

int valueIndex = table.columnIndex(column2);
int keyIndex = table.columnIndex(column1);
List<String> valueColumnNames = getValueColumnNames(table, pivotColumn);

String key;
if(multiAggregated){
for (String colName : valueColumnNames)
for(NumericColumn<?> aggColumn : aggregatedColumns) {
pivotTable.addColumns(DoubleColumn.create(colName + "." + aggColumn.name()));
}
}
else{
for (String colName : valueColumnNames) {
pivotTable.addColumns(DoubleColumn.create(colName));
}
}

for (TableSlice slice : tsg.getSlices()) {
key = String.valueOf(slice.get(0, keyIndex));
pivotTable.column(0).appendCell(key);

for (int i = 0; i < groupingColumns.size(); i++) {
String key = String.valueOf(slice.get(0, table.columnIndex(groupingColumns.get(i))));
pivotTable.column(i).appendCell(key);
}

Map<String, Double> valueMap =
getValueMap(column1, column2, values, valueIndex, slice, aggregateFunction);
getValueMap(groupingColumns, pivotColumn, aggregatedColumns, slice, aggregateFunction);

for (String columnName : valueColumnNames) {
Double aDouble = valueMap.get(columnName);
NumericColumn<?> pivotValueColumn = pivotTable.numberColumn(columnName);
if (aDouble == null) {
pivotValueColumn.appendMissing();
} else {
pivotValueColumn.appendObj(aDouble);
}
}
for (NumericColumn<?> aggregatedColumn: aggregatedColumns) {

String appendedColumnName;

if(multiAggregated){
appendedColumnName = columnName + "." + aggregatedColumn.name();
} else {
appendedColumnName = columnName;
}

NumericColumn<?> pivotValueColumn = pivotTable.numberColumn(appendedColumnName);

Double aDouble = valueMap.get(appendedColumnName);

if (aDouble == null) {
pivotValueColumn.appendMissing();
} else {
pivotValueColumn.appendObj(aDouble);
}
}
}

}

return pivotTable;
}

private static Map<String, Double> getValueMap(
CategoricalColumn<?> column1,
CategoricalColumn<?> column2,
NumericColumn<?> values,
int valueIndex,
List<CategoricalColumn<?>> groupingColumns,
CategoricalColumn<?> pivotColumn,
List<NumericColumn<?>> aggregatedColumns,
TableSlice slice,
AggregateFunction<?, ?> function) {

boolean multiAggregated = aggregatedColumns.size() > 1;
Table temp = slice.asTable();
Table summary = temp.summarize(values.name(), function).by(column1.name(), column2.name());
List<CategoricalColumn<?>> allKeyColumns = new LinkedList<>(groupingColumns);
allKeyColumns.add(pivotColumn);

List<String> aggregatedColumnNames = aggregatedColumns.stream().map(NumericColumn::name).collect(Collectors.toList());

Table summary = temp.summarize(aggregatedColumnNames, function).by(allKeyColumns.stream().map(CategoricalColumn::name).toArray(String[]::new));

Map<String, Double> valueMap = new HashMap<>();
NumericColumn<?> nc = summary.numberColumn(summary.columnCount() - 1);
for (int i = 0; i < summary.rowCount(); i++) {
valueMap.put(String.valueOf(summary.get(i, 1)), nc.getDouble(i));

if(multiAggregated){
for (int i = 0; i < summary.rowCount(); i++) {
for (int k = 0; k < aggregatedColumns.size(); k++) {
NumericColumn<?> nc = summary.numberColumn(groupingColumns.size() + k + 1);
valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())) + "." + aggregatedColumns.get(k).name(), nc.getDouble(i));
}
}
}
else{
NumericColumn<?> nc = summary.numberColumn(summary.columnCount() - 1);
for (int i = 0; i < summary.rowCount(); i++) {
valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())), nc.getDouble(i));
}
}

return valueMap;
}

Expand Down
84 changes: 68 additions & 16 deletions core/src/main/java/tech/tablesaw/api/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -864,35 +864,87 @@ public Table dropWhere(Selection selection) {
return newTable;
}


/**
* Returns a new column, where the first n columns are the groupingColumns. There are then p additional
* columns, which is the product of each unique value in the pivot column and aggregatedColumn. The
* values in each of the cells in these new columns are the result of applying the given AggregateFunction
* to the data in each of aggregatedColumn, grouped by the values of groupingColumn and pivotColumn.
*
* If more than one aggregatedColumn is provided then each is appended to each unique value of the pivot
* column in the format "{PivotColumnValue}.{AggregatedColumnName}
*
* @param groupingColumn
* @param pivotColumn
* @param aggregatedColumn
* @param aggregateFunction
* @return
*/
public Table pivot(
List<CategoricalColumn<?>> groupingColumns,
CategoricalColumn<?> pivotColumn,
List<NumericColumn<?>> aggregatedColumns,
AggregateFunction<?, ?> aggregateFunction) {
return PivotTable.pivot(this, groupingColumns, pivotColumn, aggregatedColumns, aggregateFunction);
}

/**
* Returns a new column, where the first n columns are the groupingColumns. There are then p additional
* columns, which is the product of each unique value in the pivot column and aggregatedColumn. The
* values in each of the cells in these new columns are the result of applying the given AggregateFunction
* to the data in each of aggregatedColumn, grouped by the values of groupingColumn and pivotColumn.
*
* If more than one aggregatedColumn is provided then each is appended to each unique value of the pivot
* column in the format "{PivotColumnValue}.{AggregatedColumnName}
*
* @param groupingColumnNames
* @param pivotColumnName
* @param aggregatedColumnNames
* @param aggregateFunction
* @return
*/
public Table pivot(
List<String> groupingColumnNames,
String pivotColumnName,
List<String> aggregatedColumnNames,
AggregateFunction<?, ?> aggregateFunction) {
return pivot(
groupingColumnNames.stream().map(this::categoricalColumn).collect(Collectors.toList()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codacy is complaining that you're doing Collectors.toList here when toList is already statically imported. I'm fine either way to want it as long as it's consistent

categoricalColumn(pivotColumnName),
aggregatedColumnNames.stream().map(this::numberColumn).collect(Collectors.toList()),
aggregateFunction);
}

/**
* Returns a pivot on this table, where: The first column contains unique values from the index
* column1 There are n additional columns, one for each unique value in column2 The values in each
* of the cells in these new columns are the result of applying the given AggregateFunction to the
* data in column3, grouped by the values of column1 and column2
* groupingColumn There are n additional columns, one for each unique value in the pivotColumn. The
* values in each of the cells in these new columns are the result of applying the given AggregateFunction
* to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn
*/
public Table pivot(
CategoricalColumn<?> column1,
CategoricalColumn<?> column2,
NumericColumn<?> column3,
CategoricalColumn<?> groupingColumn,
CategoricalColumn<?> pivotColumn,
NumericColumn<?> aggregatedColumn,
AggregateFunction<?, ?> aggregateFunction) {
return PivotTable.pivot(this, column1, column2, column3, aggregateFunction);
return PivotTable.pivot(this, groupingColumn, pivotColumn, aggregatedColumn, aggregateFunction);
}


/**
* Returns a pivot on this table, where: The first column contains unique values from the index
* column1 There are n additional columns, one for each unique value in column2 The values in each
* of the cells in these new columns are the result of applying the given AggregateFunction to the
* data in column3, grouped by the values of column1 and column2
* groupingColumn There are n additional columns, one for each unique value in the pivotColumn The
* values in each of the cells in these new columns are the result of applying the given AggregateFunction
* to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn
*/
public Table pivot(
String column1Name,
String column2Name,
String column3Name,
String groupingColumnName,
String pivotColumnName,
String aggregatedColumnName,
AggregateFunction<?, ?> aggregateFunction) {
return pivot(
categoricalColumn(column1Name),
categoricalColumn(column2Name),
numberColumn(column3Name),
categoricalColumn(groupingColumnName),
categoricalColumn(pivotColumnName),
numberColumn(pivotColumnName),
aggregateFunction);
}

Expand Down
75 changes: 74 additions & 1 deletion core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import org.junit.jupiter.api.Test;
import tech.tablesaw.api.Table;
import tech.tablesaw.io.csv.CsvReadOptions;
import java.util.List;

public class PivotTableTest {

/**
* Illustrate usage of pivot function with a single grouping, pivot and aggregated columns
* @throws Exception
*/
@Test
public void pivot() throws Exception {
public void pivotSingle() throws Exception {
Table t =
Table.read()
.csv(CsvReadOptions.builder("../data/bush.csv").missingValueIndicator(":").build());
Expand All @@ -30,4 +35,72 @@ public void pivot() throws Exception {
assertTrue(pivot.columnNames().contains("2004"));
assertEquals(6, pivot.rowCount());
}


@Test
public void pivotMultipleGroupAndAggregate() throws Exception {
Table t =
Table.read()
.csv(CsvReadOptions.builder("../data/baseball.csv").build());

Table pivot =
t.pivot(
List.of("Team","League"),
"Year",
List.of("RS","RA","W"),
AggregateFunctions.mean);

assertTrue(pivot.columnNames().contains("Team"));
assertTrue(pivot.columnNames().contains("League"));
assertTrue(pivot.columnNames().contains("2001.RS"));
assertTrue(pivot.columnNames().contains("2001.RA"));
assertTrue(pivot.columnNames().contains("2001.W"));
assertEquals(143, pivot.columnCount());
assertEquals(40, pivot.rowCount());
}

@Test
public void pivotMultipleGroup() throws Exception {
Table t =
Table.read()
.csv(CsvReadOptions.builder("../data/baseball.csv").build());

Table pivot =
t.pivot(
List.of("Team","League"),
"Year",
List.of("RS"),
AggregateFunctions.mean);

assertTrue(pivot.columnNames().contains("Team"));
assertTrue(pivot.columnNames().contains("League"));
assertTrue(pivot.columnNames().contains("2001"));
assertTrue(pivot.columnNames().contains("2002"));
assertTrue(pivot.columnNames().contains("2003"));
assertEquals(49, pivot.columnCount());
assertEquals(40, pivot.rowCount());
}

@Test
public void pivotMultipleAggregate() throws Exception {
Table t =
Table.read()
.csv(CsvReadOptions.builder("../data/baseball.csv").build());

Table pivot =
t.pivot(
List.of("League"),
"Year",
List.of("RS","RA","W"),
AggregateFunctions.mean);

assertTrue(!pivot.columnNames().contains("Team"));
assertTrue(pivot.columnNames().contains("League"));
assertTrue(pivot.columnNames().contains("2001.RS"));
assertTrue(pivot.columnNames().contains("2001.RA"));
assertTrue(pivot.columnNames().contains("2001.W"));
assertEquals(142, pivot.columnCount());
assertEquals(2, pivot.rowCount());
}

}