Skip to content

Commit

Permalink
Track memory in rate aggregation function (#106730)
Browse files Browse the repository at this point in the history
We should track the memory usage of each individual state in the rate
aggregation function.

Relates #106703
  • Loading branch information
dnhatn committed Mar 26, 2024
1 parent 59354e3 commit cdb2e58
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public GroupingAggregatorImplementer(
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
.filter(f -> false == f.type().equals(BIG_ARRAYS))
.filter(f -> false == f.type().equals(BIG_ARRAYS) && false == f.type().equals(DRIVER_CONTEXT))
.collect(Collectors.toList());

this.implementation = ClassName.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.compute.ann.GroupingAggregator;
Expand All @@ -35,9 +36,9 @@
@IntermediateState(name = "resets", type = "DOUBLE") }
)
public class RateDoubleAggregator {
public static DoubleRateGroupingState initGrouping(BigArrays bigArrays, long unitInMillis) {
// TODO: pass BlockFactory instead bigArrays so we can use the breaker
return new DoubleRateGroupingState(bigArrays, unitInMillis);

/**
 * Creates the grouping state for the double rate aggregation.
 *
 * @param driverContext supplies the {@code BigArrays} and the circuit breaker handed to the state
 * @param unitInMillis  passed through unchanged to the grouping state
 * @return a new, empty {@link DoubleRateGroupingState}
 */
public static DoubleRateGroupingState initGrouping(DriverContext driverContext, long unitInMillis) {
    var arrays = driverContext.bigArrays();
    var stateBreaker = driverContext.breaker();
    return new DoubleRateGroupingState(arrays, stateBreaker, unitInMillis);
}

public static void combine(DoubleRateGroupingState current, int groupId, long timestamp, double value) {
Expand Down Expand Up @@ -68,7 +69,7 @@ public static Block evaluateFinal(DoubleRateGroupingState state, IntVector selec
return state.evaluateFinal(selected, driverContext.blockFactory());
}

private static class DoubleRateState implements Accountable {
private static class DoubleRateState {
static final long BASE_RAM_USAGE = RamUsageEstimator.sizeOfObject(DoubleRateState.class);
final long[] timestamps; // descending order
final double[] values;
Expand Down Expand Up @@ -101,19 +102,23 @@ int entries() {
return timestamps.length;
}

@Override
public long ramBytesUsed() {
return BASE_RAM_USAGE;
/**
 * Estimates the memory taken by one state holding {@code entries} samples.
 * The estimate is {@code BASE_RAM_USAGE} plus the aligned sizes of a {@code long[]} of timestamps
 * and a {@code double[]} of values, each of length {@code entries}.
 *
 * @param entries number of (timestamp, value) pairs held by the state
 * @return estimated size in bytes
 */
static long bytesUsed(int entries) {
    // Widen to long before multiplying so large entry counts cannot overflow int arithmetic.
    final long timestampBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * entries;
    final long valueBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * entries;
    return BASE_RAM_USAGE + RamUsageEstimator.alignObjectSize(timestampBytes) + RamUsageEstimator.alignObjectSize(valueBytes);
}
}

public static final class DoubleRateGroupingState implements Releasable, Accountable, GroupingAggregatorState {
private ObjectArray<DoubleRateState> states;
private final long unitInMillis;
private final BigArrays bigArrays;
private final CircuitBreaker breaker;
private long stateBytes; // for individual states

DoubleRateGroupingState(BigArrays bigArrays, long unitInMillis) {
DoubleRateGroupingState(BigArrays bigArrays, CircuitBreaker breaker, long unitInMillis) {
this.bigArrays = bigArrays;
this.breaker = breaker;
this.states = bigArrays.newObjectArray(1);
this.unitInMillis = unitInMillis;
}
Expand All @@ -122,16 +127,25 @@ void ensureCapacity(int groupId) {
states = bigArrays.grow(states, groupId + 1);
}

/**
 * Adjusts the circuit breaker and the locally tracked {@code stateBytes} by {@code bytes}.
 * <p>
 * A positive delta reserves memory and may trip the breaker; the exception then propagates and
 * {@code stateBytes} is left unchanged. A zero or negative delta releases memory and is applied
 * without breaking, so a release (for example the one issued from {@code close()}) cannot fail
 * and leave the breaker over-accounted.
 *
 * @param bytes number of bytes to reserve (positive) or release (negative)
 */
void adjustBreaker(long bytes) {
    if (bytes > 0) {
        breaker.addEstimateBytesAndMaybeBreak(bytes, "<<rate aggregation>>");
    } else {
        // Releasing must never trip the breaker.
        breaker.addWithoutBreaking(bytes);
    }
    stateBytes += bytes;
    assert stateBytes >= 0 : stateBytes;
}

void append(int groupId, long timestamp, double value) {
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(DoubleRateState.bytesUsed(1));
state = new DoubleRateState(new long[] { timestamp }, new double[] { value });
states.set(groupId, state);
} else {
if (state.entries() == 1) {
adjustBreaker(DoubleRateState.bytesUsed(2));
state = new DoubleRateState(new long[] { state.timestamps[0], timestamp }, new double[] { state.values[0], value });
states.set(groupId, state);
adjustBreaker(-DoubleRateState.bytesUsed(1)); // old state
} else {
state.append(timestamp, value);
}
Expand All @@ -147,6 +161,7 @@ void combine(int groupId, LongBlock timestamps, DoubleBlock values, double reset
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(DoubleRateState.bytesUsed(valueCount));
state = new DoubleRateState(valueCount);
states.set(groupId, state);
// TODO: add bulk_copy to Block
Expand All @@ -155,9 +170,11 @@ void combine(int groupId, LongBlock timestamps, DoubleBlock values, double reset
state.values[i] = values.getDouble(firstIndex + i);
}
} else {
adjustBreaker(DoubleRateState.bytesUsed(state.entries() + valueCount));
var newState = new DoubleRateState(state.entries() + valueCount);
states.set(groupId, newState);
merge(state, newState, firstIndex, valueCount, timestamps, values);
adjustBreaker(-DoubleRateState.bytesUsed(state.entries())); // old state
}
state.reset += reset;
}
Expand Down Expand Up @@ -193,12 +210,12 @@ void merge(DoubleRateState curr, DoubleRateState dst, int firstIndex, int rightC

@Override
public long ramBytesUsed() {
return states.ramBytesUsed();
return states.ramBytesUsed() + stateBytes;
}

@Override
public void close() {
Releasables.close(states);
Releasables.close(states, () -> adjustBreaker(-stateBytes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.compute.ann.GroupingAggregator;
Expand Down Expand Up @@ -36,9 +37,9 @@
@IntermediateState(name = "resets", type = "DOUBLE") }
)
public class RateIntAggregator {
public static IntRateGroupingState initGrouping(BigArrays bigArrays, long unitInMillis) {
// TODO: pass BlockFactory instead bigArrays so we can use the breaker
return new IntRateGroupingState(bigArrays, unitInMillis);

/**
 * Creates the grouping state for the int rate aggregation.
 *
 * @param driverContext supplies the {@code BigArrays} and the circuit breaker handed to the state
 * @param unitInMillis  passed through unchanged to the grouping state
 * @return a new, empty {@link IntRateGroupingState}
 */
public static IntRateGroupingState initGrouping(DriverContext driverContext, long unitInMillis) {
    var arrays = driverContext.bigArrays();
    var stateBreaker = driverContext.breaker();
    return new IntRateGroupingState(arrays, stateBreaker, unitInMillis);
}

public static void combine(IntRateGroupingState current, int groupId, long timestamp, int value) {
Expand Down Expand Up @@ -69,7 +70,7 @@ public static Block evaluateFinal(IntRateGroupingState state, IntVector selected
return state.evaluateFinal(selected, driverContext.blockFactory());
}

private static class IntRateState implements Accountable {
private static class IntRateState {
static final long BASE_RAM_USAGE = RamUsageEstimator.sizeOfObject(IntRateState.class);
final long[] timestamps; // descending order
final int[] values;
Expand Down Expand Up @@ -102,19 +103,23 @@ int entries() {
return timestamps.length;
}

@Override
public long ramBytesUsed() {
return BASE_RAM_USAGE;
/**
 * Estimates the memory taken by one state holding {@code entries} samples.
 * The estimate is {@code BASE_RAM_USAGE} plus the aligned sizes of a {@code long[]} of timestamps
 * and an {@code int[]} of values, each of length {@code entries}.
 *
 * @param entries number of (timestamp, value) pairs held by the state
 * @return estimated size in bytes
 */
static long bytesUsed(int entries) {
    // Widen to long before multiplying so large entry counts cannot overflow int arithmetic.
    final long timestampBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * entries;
    final long valueBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Integer.BYTES * entries;
    return BASE_RAM_USAGE + RamUsageEstimator.alignObjectSize(timestampBytes) + RamUsageEstimator.alignObjectSize(valueBytes);
}
}

public static final class IntRateGroupingState implements Releasable, Accountable, GroupingAggregatorState {
private ObjectArray<IntRateState> states;
private final long unitInMillis;
private final BigArrays bigArrays;
private final CircuitBreaker breaker;
private long stateBytes; // for individual states

IntRateGroupingState(BigArrays bigArrays, long unitInMillis) {
IntRateGroupingState(BigArrays bigArrays, CircuitBreaker breaker, long unitInMillis) {
this.bigArrays = bigArrays;
this.breaker = breaker;
this.states = bigArrays.newObjectArray(1);
this.unitInMillis = unitInMillis;
}
Expand All @@ -123,16 +128,25 @@ void ensureCapacity(int groupId) {
states = bigArrays.grow(states, groupId + 1);
}

/**
 * Adjusts the circuit breaker and the locally tracked {@code stateBytes} by {@code bytes}.
 * <p>
 * A positive delta reserves memory and may trip the breaker; the exception then propagates and
 * {@code stateBytes} is left unchanged. A zero or negative delta releases memory and is applied
 * without breaking, so a release (for example the one issued from {@code close()}) cannot fail
 * and leave the breaker over-accounted.
 *
 * @param bytes number of bytes to reserve (positive) or release (negative)
 */
void adjustBreaker(long bytes) {
    if (bytes > 0) {
        breaker.addEstimateBytesAndMaybeBreak(bytes, "<<rate aggregation>>");
    } else {
        // Releasing must never trip the breaker.
        breaker.addWithoutBreaking(bytes);
    }
    stateBytes += bytes;
    assert stateBytes >= 0 : stateBytes;
}

void append(int groupId, long timestamp, int value) {
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(IntRateState.bytesUsed(1));
state = new IntRateState(new long[] { timestamp }, new int[] { value });
states.set(groupId, state);
} else {
if (state.entries() == 1) {
adjustBreaker(IntRateState.bytesUsed(2));
state = new IntRateState(new long[] { state.timestamps[0], timestamp }, new int[] { state.values[0], value });
states.set(groupId, state);
adjustBreaker(-IntRateState.bytesUsed(1)); // old state
} else {
state.append(timestamp, value);
}
Expand All @@ -148,6 +162,7 @@ void combine(int groupId, LongBlock timestamps, IntBlock values, double reset, i
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(IntRateState.bytesUsed(valueCount));
state = new IntRateState(valueCount);
states.set(groupId, state);
// TODO: add bulk_copy to Block
Expand All @@ -156,9 +171,11 @@ void combine(int groupId, LongBlock timestamps, IntBlock values, double reset, i
state.values[i] = values.getInt(firstIndex + i);
}
} else {
adjustBreaker(IntRateState.bytesUsed(state.entries() + valueCount));
var newState = new IntRateState(state.entries() + valueCount);
states.set(groupId, newState);
merge(state, newState, firstIndex, valueCount, timestamps, values);
adjustBreaker(-IntRateState.bytesUsed(state.entries())); // old state
}
state.reset += reset;
}
Expand Down Expand Up @@ -194,12 +211,12 @@ void merge(IntRateState curr, IntRateState dst, int firstIndex, int rightCount,

@Override
public long ramBytesUsed() {
return states.ramBytesUsed();
return states.ramBytesUsed() + stateBytes;
}

@Override
public void close() {
Releasables.close(states);
Releasables.close(states, () -> adjustBreaker(-stateBytes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.compute.ann.GroupingAggregator;
Expand All @@ -35,9 +36,9 @@
@IntermediateState(name = "resets", type = "DOUBLE") }
)
public class RateLongAggregator {
public static LongRateGroupingState initGrouping(BigArrays bigArrays, long unitInMillis) {
// TODO: pass BlockFactory instead bigArrays so we can use the breaker
return new LongRateGroupingState(bigArrays, unitInMillis);

/**
 * Creates the grouping state for the long rate aggregation.
 *
 * @param driverContext supplies the {@code BigArrays} and the circuit breaker handed to the state
 * @param unitInMillis  passed through unchanged to the grouping state
 * @return a new, empty {@link LongRateGroupingState}
 */
public static LongRateGroupingState initGrouping(DriverContext driverContext, long unitInMillis) {
    var arrays = driverContext.bigArrays();
    var stateBreaker = driverContext.breaker();
    return new LongRateGroupingState(arrays, stateBreaker, unitInMillis);
}

public static void combine(LongRateGroupingState current, int groupId, long timestamp, long value) {
Expand Down Expand Up @@ -68,7 +69,7 @@ public static Block evaluateFinal(LongRateGroupingState state, IntVector selecte
return state.evaluateFinal(selected, driverContext.blockFactory());
}

private static class LongRateState implements Accountable {
private static class LongRateState {
static final long BASE_RAM_USAGE = RamUsageEstimator.sizeOfObject(LongRateState.class);
final long[] timestamps; // descending order
final long[] values;
Expand Down Expand Up @@ -101,19 +102,23 @@ int entries() {
return timestamps.length;
}

@Override
public long ramBytesUsed() {
return BASE_RAM_USAGE;
/**
 * Estimates the memory taken by one state holding {@code entries} samples.
 * The estimate is {@code BASE_RAM_USAGE} plus the aligned sizes of two {@code long[]} arrays
 * (timestamps and values), each of length {@code entries}.
 *
 * @param entries number of (timestamp, value) pairs held by the state
 * @return estimated size in bytes
 */
static long bytesUsed(int entries) {
    // Widen to long before multiplying so large entry counts cannot overflow int arithmetic.
    final long timestampBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * entries;
    final long valueBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * entries;
    return BASE_RAM_USAGE + RamUsageEstimator.alignObjectSize(timestampBytes) + RamUsageEstimator.alignObjectSize(valueBytes);
}
}

public static final class LongRateGroupingState implements Releasable, Accountable, GroupingAggregatorState {
private ObjectArray<LongRateState> states;
private final long unitInMillis;
private final BigArrays bigArrays;
private final CircuitBreaker breaker;
private long stateBytes; // for individual states

LongRateGroupingState(BigArrays bigArrays, long unitInMillis) {
LongRateGroupingState(BigArrays bigArrays, CircuitBreaker breaker, long unitInMillis) {
this.bigArrays = bigArrays;
this.breaker = breaker;
this.states = bigArrays.newObjectArray(1);
this.unitInMillis = unitInMillis;
}
Expand All @@ -122,16 +127,25 @@ void ensureCapacity(int groupId) {
states = bigArrays.grow(states, groupId + 1);
}

/**
 * Adjusts the circuit breaker and the locally tracked {@code stateBytes} by {@code bytes}.
 * <p>
 * A positive delta reserves memory and may trip the breaker; the exception then propagates and
 * {@code stateBytes} is left unchanged. A zero or negative delta releases memory and is applied
 * without breaking, so a release (for example the one issued from {@code close()}) cannot fail
 * and leave the breaker over-accounted.
 *
 * @param bytes number of bytes to reserve (positive) or release (negative)
 */
void adjustBreaker(long bytes) {
    if (bytes > 0) {
        breaker.addEstimateBytesAndMaybeBreak(bytes, "<<rate aggregation>>");
    } else {
        // Releasing must never trip the breaker.
        breaker.addWithoutBreaking(bytes);
    }
    stateBytes += bytes;
    assert stateBytes >= 0 : stateBytes;
}

void append(int groupId, long timestamp, long value) {
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(LongRateState.bytesUsed(1));
state = new LongRateState(new long[] { timestamp }, new long[] { value });
states.set(groupId, state);
} else {
if (state.entries() == 1) {
adjustBreaker(LongRateState.bytesUsed(2));
state = new LongRateState(new long[] { state.timestamps[0], timestamp }, new long[] { state.values[0], value });
states.set(groupId, state);
adjustBreaker(-LongRateState.bytesUsed(1)); // old state
} else {
state.append(timestamp, value);
}
Expand All @@ -147,6 +161,7 @@ void combine(int groupId, LongBlock timestamps, LongBlock values, double reset,
ensureCapacity(groupId);
var state = states.get(groupId);
if (state == null) {
adjustBreaker(LongRateState.bytesUsed(valueCount));
state = new LongRateState(valueCount);
states.set(groupId, state);
// TODO: add bulk_copy to Block
Expand All @@ -155,9 +170,11 @@ void combine(int groupId, LongBlock timestamps, LongBlock values, double reset,
state.values[i] = values.getLong(firstIndex + i);
}
} else {
adjustBreaker(LongRateState.bytesUsed(state.entries() + valueCount));
var newState = new LongRateState(state.entries() + valueCount);
states.set(groupId, newState);
merge(state, newState, firstIndex, valueCount, timestamps, values);
adjustBreaker(-LongRateState.bytesUsed(state.entries())); // old state
}
state.reset += reset;
}
Expand Down Expand Up @@ -193,12 +210,12 @@ void merge(LongRateState curr, LongRateState dst, int firstIndex, int rightCount

@Override
public long ramBytesUsed() {
return states.ramBytesUsed();
return states.ramBytesUsed() + stateBytes;
}

@Override
public void close() {
Releasables.close(states);
Releasables.close(states, () -> adjustBreaker(-stateBytes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public RateDoubleGroupingAggregatorFunction(List<Integer> channels,

public static RateDoubleGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, long unitInMillis) {
return new RateDoubleGroupingAggregatorFunction(channels, RateDoubleAggregator.initGrouping(driverContext.bigArrays(), unitInMillis), driverContext, unitInMillis);
return new RateDoubleGroupingAggregatorFunction(channels, RateDoubleAggregator.initGrouping(driverContext, unitInMillis), driverContext, unitInMillis);
}

public static List<IntermediateStateDesc> intermediateStateDesc() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public RateIntGroupingAggregatorFunction(List<Integer> channels,

public static RateIntGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, long unitInMillis) {
return new RateIntGroupingAggregatorFunction(channels, RateIntAggregator.initGrouping(driverContext.bigArrays(), unitInMillis), driverContext, unitInMillis);
return new RateIntGroupingAggregatorFunction(channels, RateIntAggregator.initGrouping(driverContext, unitInMillis), driverContext, unitInMillis);
}

public static List<IntermediateStateDesc> intermediateStateDesc() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public RateLongGroupingAggregatorFunction(List<Integer> channels,

public static RateLongGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, long unitInMillis) {
return new RateLongGroupingAggregatorFunction(channels, RateLongAggregator.initGrouping(driverContext.bigArrays(), unitInMillis), driverContext, unitInMillis);
return new RateLongGroupingAggregatorFunction(channels, RateLongAggregator.initGrouping(driverContext, unitInMillis), driverContext, unitInMillis);
}

public static List<IntermediateStateDesc> intermediateStateDesc() {
Expand Down

0 comments on commit cdb2e58

Please sign in to comment.