Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Changed
- Add CompletionStage variants to methods in the Client Interface and default to ActionListener impl ([#18998](https://github.com/opensearch-project/OpenSearch/pull/18998))
- IllegalArgumentException when scroll ID references a node not found in Cluster ([#19031](https://github.com/opensearch-project/OpenSearch/pull/19031))
- Adding ScriptedAvg class to painless spi to allowlist usage from plugins ([#19006](https://github.com/opensearch-project/OpenSearch/pull/19006))

### Fixed
- Fix unnecessary refreshes on update preparation failures ([#15261](https://github.com/opensearch-project/OpenSearch/issues/15261))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,8 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep
o.writeByte((byte) 27);
o.writeSemverRange((SemverRange) v);
});
// Have registered ScriptedAvg class with byte 28 in Streamables.java, so that we do not need the implementation reside in the
// server module
WRITERS = Collections.unmodifiableMap(writers);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,9 @@ class org.opensearch.index.query.IntervalFilterScript$Interval {
class org.opensearch.script.ScoreScript$ExplanationHolder {
void set(String)
}

class org.opensearch.search.aggregations.metrics.ScriptedAvg {
(double,long)
double getSum()
long getCount()
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable.WriteableRegistry;
import org.opensearch.search.aggregations.metrics.ScriptedAvg;

/**
* This utility class registers generic types for streaming over the wire using
Expand Down Expand Up @@ -45,6 +46,12 @@ private static void registerWriters() {
o.writeByte((byte) 22);
((GeoPoint) v).writeTo(o);
});

WriteableRegistry.registerWriter(ScriptedAvg.class, (o, v) -> {
o.writeByte((byte) 28);
((ScriptedAvg) v).writeTo(o);
});

}

/**
Expand All @@ -55,5 +62,6 @@ private static void registerWriters() {
private static void registerReaders() {
/* {@link GeoPoint} */
WriteableRegistry.registerReader(Byte.valueOf((byte) 22), GeoPoint::new);
WriteableRegistry.registerReader(Byte.valueOf((byte) 28), ScriptedAvg::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,21 @@ public InternalAvg reduce(List<InternalAggregation> aggregations, ReduceContext
for (InternalAggregation aggregation : aggregations) {
if (aggregation instanceof InternalScriptedMetric) {
// If using InternalScriptedMetric in place of InternalAvg
Object value = ((InternalScriptedMetric) aggregation).aggregation();
if (value instanceof ScriptedAvg scriptedAvg) {
count += scriptedAvg.getCount();
kahanSummation.add(scriptedAvg.getSum());
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] avg aggregation. Expected ScriptedAvg "
+ "but received ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
List<Object> aggList = ((InternalScriptedMetric) aggregation).aggregationsList();
for (Object value : aggList) {
if (value instanceof ScriptedAvg scriptedAvg) {
count += scriptedAvg.getCount();
kahanSummation.add(scriptedAvg.getSum());
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] avg aggregation. Expected ScriptedAvg "
+ "but received ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
}
}
} else {
// Original handling for InternalAvg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
for (InternalAggregation aggregation : aggregations) {
if (aggregation instanceof InternalScriptedMetric) {
// If using InternalScriptedMetric in place of InternalValueCount
Object value = ((InternalScriptedMetric) aggregation).aggregation();
if (value instanceof Number) {
valueCount += ((Number) value).longValue();
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
List<Object> aggList = ((InternalScriptedMetric) aggregation).aggregationsList();
for (Object value : aggList) {
if (value instanceof Number) {
valueCount += ((Number) value).longValue();
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
}
}
} else {
// Original handling for InternalValueCount
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,7 @@
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
Expand Down Expand Up @@ -79,4 +58,5 @@ public double getSum() {
public long getCount() {
return count;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
@Override
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object result = aggStateForResult(owningBucketOrdinal).combine();
StreamOutput.checkWriteable(result);
if (result.getClass() != ScriptedAvg.class) StreamOutput.checkWriteable(result);
return new InternalScriptedMetric(name, singletonList(result), reduceScript, metadata());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ public void testReduceWithScriptedMetric() {
// Add ScriptedMetric with ScriptedAvg object
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.getName()).thenReturn(name);
ScriptedAvg scriptedAvg = new ScriptedAvg(100.0, 20L);
when(scriptedMetric1.aggregation()).thenReturn(scriptedAvg);
List<Object> aggList = new ArrayList<>();
aggList.add(new ScriptedAvg(100.0, 20L));
when(scriptedMetric1.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric1);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand Down Expand Up @@ -175,7 +176,9 @@ public void testReduceWithScriptedMetricInvalidType() {
// Add ScriptedMetric with invalid return type (String instead of double[])
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.getName()).thenReturn(name);
when(scriptedMetric1.aggregation()).thenReturn("invalid_type");
List<Object> aggList = new ArrayList<>();
aggList.add("invalid_type");
when(scriptedMetric1.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric1);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand All @@ -199,7 +202,9 @@ public void testReduceWithScriptedMetricInvalidArrayLength() {
// Add ScriptedMetric with double array of wrong length (should be 2)
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
when(scriptedMetric.getName()).thenReturn(name);
when(scriptedMetric.aggregation()).thenReturn(new double[] { 100.0, 20.0, 30.0 }); // length 3 instead of 2
List<Object> aggList = new ArrayList<>();
aggList.add(new double[] { 100.0, 20.0, 30.0 }); // Add double array to list
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,23 @@ public void testReduceWithScriptedMetric() {

// Add ScriptedMetric with Long value
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.aggregation()).thenReturn(20L);
List<Object> aggList1 = new ArrayList<>();
aggList1.add(20L);
when(scriptedMetric1.aggregationsList()).thenReturn(aggList1);
aggregations.add(scriptedMetric1);

// Add ScriptedMetric with Integer value
InternalScriptedMetric scriptedMetric2 = mock(InternalScriptedMetric.class);
when(scriptedMetric2.aggregation()).thenReturn(30);
List<Object> aggList2 = new ArrayList<>();
aggList2.add(30);
when(scriptedMetric2.aggregationsList()).thenReturn(aggList2);
aggregations.add(scriptedMetric2);

// Add ScriptedMetric with Double value
InternalScriptedMetric scriptedMetric3 = mock(InternalScriptedMetric.class);
when(scriptedMetric3.aggregation()).thenReturn(10.5);
List<Object> aggList3 = new ArrayList<>();
aggList3.add(10.5);
when(scriptedMetric3.aggregationsList()).thenReturn(aggList3);
aggregations.add(scriptedMetric3);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
Expand All @@ -92,6 +98,7 @@ public void testReduceWithScriptedMetric() {
}

public void testReduceWithInternalValueCountOnly() {
// This test remains unchanged as it doesn't use ScriptedMetric
String name = "test_value_count";
List<InternalAggregation> aggregations = new ArrayList<>();

Expand All @@ -116,7 +123,9 @@ public void testReduceWithScriptedMetricInvalidValue() {

// Add ScriptedMetric with invalid value type (String instead of Number)
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
when(scriptedMetric.aggregation()).thenReturn("invalid_value");
List<Object> aggList = new ArrayList<>();
aggList.add("invalid_value");
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
Expand All @@ -133,6 +142,29 @@ public void testReduceWithScriptedMetricInvalidValue() {
);
}

public void testReduceWithMultipleValuesInList() {
String name = "test_scripted_metric";
List<InternalAggregation> aggregations = new ArrayList<>();

// Add regular InternalValueCount
aggregations.add(new InternalValueCount(name, 50L, null));

// Add ScriptedMetric with multiple values in the list
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
List<Object> aggList = new ArrayList<>();
aggList.add(20L);
aggList.add(30);
aggList.add(10.5);
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
InternalValueCount reduced = (InternalValueCount) valueCount.reduce(aggregations, null);

// Expected: 50 + 20 + 30 + 10 = 110
assertEquals(110L, reduced.getValue());
}

@Override
protected InternalValueCount mutateInstance(InternalValueCount instance) {
String name = instance.getName();
Expand Down
Loading