Skip to content

Commit

Permalink
change request type to ActionRequest BaseGetConfigTransportAction to …
Browse files Browse the repository at this point in the history
…fix class cast exception (#1221)

Signed-off-by: Hailong Cui <ihailong@amazon.com>
  • Loading branch information
Hailong-am committed Jun 5, 2024
1 parent 05b04b9 commit 9eed2ff
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ public AnomalyDetector(StreamInput input) throws IOException {
} else {
this.imputationOption = null;
}
this.recencyEmphasis = input.readInt();
this.seasonIntervals = input.readInt();
this.historyIntervals = input.readInt();
this.recencyEmphasis = input.readOptionalInt();
this.seasonIntervals = input.readOptionalInt();
this.historyIntervals = input.readOptionalInt();
if (input.readBoolean()) {
this.rules = input.readList(Rule::new);
}
Expand Down Expand Up @@ -333,9 +333,9 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
output.writeInt(recencyEmphasis);
output.writeInt(seasonIntervals);
output.writeInt(historyIntervals);
output.writeOptionalInt(recencyEmphasis);
output.writeOptionalInt(seasonIntervals);
output.writeOptionalInt(historyIntervals);
if (rules != null) {
output.writeBoolean(true);
output.writeList(rules);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.MultiGetItemResponse;
import org.opensearch.action.get.MultiGetRequest;
Expand Down Expand Up @@ -74,7 +75,7 @@
import com.google.common.collect.Sets;

public abstract class BaseGetConfigTransportAction<GetConfigResponseType extends ActionResponse, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>, TaskManagerType extends TaskManager<TaskCacheManagerType, TaskTypeEnum, TaskClass, IndexType, IndexManagementType>, ConfigType extends Config, EntityProfileActionType extends ActionType<EntityProfileResponse>, EntityProfileRunnerType extends EntityProfileRunner<EntityProfileActionType>, TaskProfileType extends TaskProfile<TaskClass>, ConfigProfileType extends ConfigProfile<TaskClass, TaskProfileType>, ProfileActionType extends ActionType<ProfileResponse>, TaskProfileRunnerType extends TaskProfileRunner<TaskClass, TaskProfileType>, ProfileRunnerType extends ProfileRunner<TaskCacheManagerType, TaskTypeEnum, TaskClass, IndexType, IndexManagementType, TaskProfileType, TaskManagerType, ConfigProfileType, ProfileActionType, TaskProfileRunnerType>>
extends HandledTransportAction<GetConfigRequest, GetConfigResponseType> {
extends HandledTransportAction<ActionRequest, GetConfigResponseType> {

private static final Logger LOG = LogManager.getLogger(BaseGetConfigTransportAction.class);

Expand Down Expand Up @@ -156,8 +157,9 @@ public BaseGetConfigTransportAction(
}

@Override
public void doExecute(Task task, GetConfigRequest request, ActionListener<GetConfigResponseType> actionListener) {
String configID = request.getConfigID();
public void doExecute(Task task, ActionRequest request, ActionListener<GetConfigResponseType> actionListener) {
GetConfigRequest getConfigRequest = GetConfigRequest.fromActionRequest(request);
String configID = getConfigRequest.getConfigID();
User user = ParseUtils.getUserContext(client);
ActionListener<GetConfigResponseType> listener = wrapRestActionListener(actionListener, FAIL_TO_GET_FORECASTER);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand All @@ -166,7 +168,7 @@ public void doExecute(Task task, GetConfigRequest request, ActionListener<GetCon
configID,
filterByEnabled,
listener,
(config) -> getExecute(request, listener),
(config) -> getExecute(getConfigRequest, listener),
client,
clusterService,
xContentRegistry,
Expand Down
49 changes: 49 additions & 0 deletions src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,28 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;

import org.junit.Assert;
import org.opensearch.ad.constant.ADCommonMessages;
import org.opensearch.ad.constant.ADCommonName;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;
import org.opensearch.timeseries.AbstractTimeSeriesTest;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.common.exception.ValidationException;
Expand Down Expand Up @@ -901,4 +915,39 @@ public void testParseAnomalyDetector_withCustomIndex_withCustomResultIndexTTL()
AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null);
assertEquals(30, (int) parsedDetector.getCustomResultIndexTTL());
}

public void testSerializeAndDeserializeAnomalyDetector() throws IOException {
// register writer and reader for type Feature
Writeable.WriteableRegistry.registerWriter(Feature.class, (o, v) -> {
o.writeByte((byte) 23);
((Feature) v).writeTo(o);
});
Writeable.WriteableRegistry.registerReader((byte) 23, Feature::new);

// write to streamOutput
AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now());
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
detector.writeTo(bytesStreamOutput);

// register namedWriteables
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new));
namedWriteables
.add(
new NamedWriteableRegistry.Entry(
AggregationBuilder.class,
ValueCountAggregationBuilder.NAME,
ValueCountAggregationBuilder::new
)
);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
StreamInput input = new NamedWriteableAwareStreamInput(streamInput, new NamedWriteableRegistry(namedWriteables));

AnomalyDetector deserializedDetector = new AnomalyDetector(input);
Assert.assertEquals(deserializedDetector, detector);
Assert.assertEquals(deserializedDetector.getSeasonIntervals(), detector.getSeasonIntervals());
}
}

0 comments on commit 9eed2ff

Please sign in to comment.