Skip to content

Commit

Permalink
GroupBy.rowNumber() #16
Browse files Browse the repository at this point in the history
  • Loading branch information
andrus committed Mar 31, 2019
1 parent 1cc0a58 commit e30fcb2
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 23 deletions.
12 changes: 2 additions & 10 deletions dflib/src/main/java/com/nhl/dflib/ColumnDataFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.nhl.dflib.map.ValueMapper;
import com.nhl.dflib.row.CrossColumnRowProxy;
import com.nhl.dflib.row.RowProxy;
import com.nhl.dflib.seq.Sequences;
import com.nhl.dflib.series.ArraySeries;
import com.nhl.dflib.series.ColumnMappedSeries;
import com.nhl.dflib.series.HeadSeries;
Expand All @@ -41,15 +42,6 @@ public ColumnDataFrame(Index columnsIndex, Series[] dataColumns) {
this.dataColumns = Objects.requireNonNull(dataColumns);
}

/**
 * Builds an array of consecutive row numbers starting at zero.
 *
 * @param h the number of elements to generate
 * @return a new array of length {@code h} whose element at position {@code i} is {@code i}
 */
protected static Integer[] rowNumberSequence(int h) {
    Integer[] numbers = new Integer[h];

    int next = 0;
    while (next < h) {
        numbers[next] = next;
        next++;
    }

    return numbers;
}

@Override
public int height() {
return dataColumns.length > 0 ? dataColumns[0].size() : 0;
Expand All @@ -72,7 +64,7 @@ public <T> Series<T> getColumn(String name) {

/**
 * Adds a column of zero-based row numbers to this DataFrame.
 *
 * @param columnName the name to give the row number column
 * @return the result of {@code addColumn(..)} called with a series holding the numbers
 * {@code 0 .. height() - 1}
 */
@Override
public DataFrame addRowNumber(String columnName) {
    // a single return: the sequence comes from the shared Sequences helper
    return addColumn(columnName, new ArraySeries<>(Sequences.numberSequence(height())));
}

@Override
Expand Down
32 changes: 32 additions & 0 deletions dflib/src/main/java/com/nhl/dflib/GroupBy.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

import com.nhl.dflib.aggregate.Aggregator;
import com.nhl.dflib.aggregate.ColumnAggregator;
import com.nhl.dflib.concat.VConcat;
import com.nhl.dflib.join.JoinType;
import com.nhl.dflib.map.RowToValueMapper;
import com.nhl.dflib.row.RowProxy;
import com.nhl.dflib.seq.Sequences;
import com.nhl.dflib.series.ArraySeries;
import com.nhl.dflib.series.IndexedSeries;
import com.nhl.dflib.sort.IndexSorter;
import com.nhl.dflib.sort.Sorters;
Expand All @@ -18,6 +22,8 @@

public class GroupBy {

private static final Index TWO_COLUMN_INDEX = Index.forLabels("0", "1");

private DataFrame ungrouped;
private Map<Object, Series<Integer>> groupsIndex;
private Map<Object, DataFrame> resolvedGroups;
Expand Down Expand Up @@ -47,6 +53,32 @@ public DataFrame getGroup(Object key) {
return resolvedGroups.computeIfAbsent(key, this::resolveGroup);
}

/**
 * A "window" function producing, for every row, its zero-based number within the group it belongs to. The values
 * are arranged in the order of the rows of the DataFrame the grouping was built from, so the returned Series can be
 * attached back to that DataFrame as a column.
 *
 * @return a new Series of per-group row numbers, ordered like the original ungrouped DataFrame
 */
public Series<Integer> rowNumbers() {

    DataFrame[] perGroup = new DataFrame[groupsIndex.size()];

    int g = 0;
    for (Series<Integer> groupIndex : groupsIndex.values()) {

        // column "0": the index series of the group; column "1": the 0..N-1 numbering within the group
        Series<?>[] columns = new Series[2];
        columns[1] = new ArraySeries<>(Sequences.numberSequence(groupIndex.size()));
        columns[0] = groupIndex;

        perGroup[g++] = new ColumnDataFrame(TWO_COLUMN_INDEX, columns);
    }

    // stack all groups, sort ascending on the index column, then keep only the numbering column
    return VConcat
            .concat(JoinType.inner, perGroup)
            .sort(0, true)
            .getColumn(1);
}

public <V extends Comparable<? super V>> GroupBy sort(RowToValueMapper<V> sortKeyExtractor) {

Comparator<RowProxy> comparator = Sorters.sorter(sortKeyExtractor);
Expand Down
6 changes: 4 additions & 2 deletions dflib/src/main/java/com/nhl/dflib/Series.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import com.nhl.dflib.series.EmptySeries;
import com.nhl.dflib.series.RangeSeries;

import static java.util.Arrays.asList;

public interface Series<T> {

static <T> Series<T> forData(T... data) {
Expand Down Expand Up @@ -49,10 +51,10 @@ default Series<T> concat(Series<? extends T>... other) {
return this;
}

Series<? extends T>[] combined = new Series[other.length + 1];
Series<T>[] combined = new Series[other.length + 1];
combined[0] = this;
System.arraycopy(other, 0, combined, 1, other.length);

return SeriesConcat.concat(combined);
return SeriesConcat.concat(asList(combined));
}
}
15 changes: 7 additions & 8 deletions dflib/src/main/java/com/nhl/dflib/concat/SeriesConcat.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@

public class SeriesConcat {

public static <T> Series<T> concat(Series<? extends T>... concat) {

public static <T> Series<T> concat(Iterable<Series<T>> concat) {
int h = 0;
int n = concat.length;
for (Series<?> s : concat) {
for (Series<? extends T> s : concat) {
h += s.size();
}

T[] data = (T[]) new Object[h];
for (int i = 0, ai = 0; i < n; i++) {
int len = concat[i].size();
concat[i].copyTo(data, 0, ai, len);
ai += len;
int offset = 0;
for (Series<? extends T> s : concat) {
int len = s.size();
s.copyTo(data, 0, offset, len);
offset += len;
}

return new ArraySeries<>(data);
Expand Down
3 changes: 0 additions & 3 deletions dflib/src/main/java/com/nhl/dflib/groupby/Grouper.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.nhl.dflib.DataFrame;
import com.nhl.dflib.GroupBy;
import com.nhl.dflib.Index;
import com.nhl.dflib.Series;
import com.nhl.dflib.map.Hasher;
import com.nhl.dflib.row.RowProxy;
Expand All @@ -27,8 +26,6 @@ public GroupBy group(DataFrame df) {
// Intentionally using generics-free map to be able to reset the internal object and avoid copying the map
Map groups = new LinkedHashMap();

Index columns = df.getColumnsIndex();

int i = 0;
for (RowProxy r : df) {
Object key = hasher.map(r);
Expand Down
13 changes: 13 additions & 0 deletions dflib/src/main/java/com/nhl/dflib/seq/Sequences.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.nhl.dflib.seq;

/**
 * Static helpers for generating number sequences.
 */
public class Sequences {

    /**
     * Creates an ascending, zero-based sequence of boxed integers.
     *
     * @param h the length of the sequence
     * @return a new array holding {@code 0, 1, ..., h - 1}; an empty array when {@code h} is zero
     */
    public static Integer[] numberSequence(int h) {
        Integer[] sequence = new Integer[h];

        // each slot receives its own position
        java.util.Arrays.setAll(sequence, position -> position);

        return sequence;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.nhl.dflib;

import com.nhl.dflib.unit.SeriesAsserts;
import org.junit.Test;

/**
 * Checks the Series returned by {@code rowNumbers()} for DataFrames grouped by column "a".
 */
public class DataFrame_GroupBy_WindowFuncsTest extends BaseDataFrameTest {

    /**
     * Key 1 appears in three non-adjacent rows, keys 2 and 0 appear once each.
     */
    @Test
    public void testGroupBy_RowNumbers0() {
        DataFrame source = createDf(Index.forLabels("a", "b", "c"),
                1, "x", "m",
                2, "y", "n",
                1, "z", "k",
                0, "a", "f",
                1, "x", "s");

        Series<Integer> numbers = source.group("a").rowNumbers();

        new SeriesAsserts(numbers).expectData(0, 0, 1, 0, 2);
    }

    /**
     * Every key is unique, so each row is the first row of its own group.
     */
    @Test
    public void testGroupBy_RowNumbers1() {
        DataFrame source = createDf(Index.forLabels("a", "b", "c"),
                3, "x", "m",
                2, "y", "n",
                1, "z", "k",
                0, "a", "f",
                -1, "x", "s");

        Series<Integer> numbers = source.group("a").rowNumbers();

        new SeriesAsserts(numbers).expectData(0, 0, 0, 0, 0);
    }

    /**
     * Key 3 appears in three rows, interleaved with single-row keys 0 and 1.
     */
    @Test
    public void testGroupBy_RowNumbers2() {
        DataFrame source = createDf(Index.forLabels("a", "b", "c"),
                3, "x", "m",
                0, "y", "n",
                3, "z", "k",
                3, "a", "f",
                1, "x", "s");

        Series<Integer> numbers = source.group("a").rowNumbers();

        new SeriesAsserts(numbers).expectData(0, 0, 1, 2, 0);
    }
}

0 comments on commit e30fcb2

Please sign in to comment.