Skip to content

Commit

Permalink
Inject NamedWriteableRegistry in AD node client (#1164)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
  • Loading branch information
ohltyler committed Feb 21, 2024
1 parent 5b85720 commit 1507dd4
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final NamedWriteableRegistry namedWriteableRegistry;

public AnomalyDetectionNodeClient(Client client) {
public AnomalyDetectionNodeClient(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.namedWriteableRegistry = namedWriteableRegistry;
}

@Override
Expand All @@ -46,14 +49,18 @@ public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionL

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
// Additionally, we need to inject the configured NamedWriteableRegistry so NamedWriteables (present in sub-fields of
// GetAnomalyDetectorResponse) are able to be re-serialized and prevent errors like the following:
// "can't read named writeable from StreamInput"
private ActionListener<GetAnomalyDetectorResponse> getAnomalyDetectorResponseActionListener(
ActionListener<GetAnomalyDetectorResponse> listener
) {
ActionListener<GetAnomalyDetectorResponse> internalListener = ActionListener.wrap(getAnomalyDetectorResponse -> {
listener.onResponse(getAnomalyDetectorResponse);
}, listener::onFailure);
ActionListener<GetAnomalyDetectorResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse);
GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse
.fromActionResponse(actionResponse, this.namedWriteableRegistry);
return response;
});
return actionListener;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.opensearch.ad.model.EntityProfile;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -218,16 +220,21 @@ public AnomalyDetector getDetector() {
return detector;
}

public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) {
public static GetAnomalyDetectorResponse fromActionResponse(
ActionResponse actionResponse,
NamedWriteableRegistry namedWriteableRegistry
) {
if (actionResponse instanceof GetAnomalyDetectorResponse) {
return (GetAnomalyDetectorResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos);
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new GetAnomalyDetectorResponse(input);
}
InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()));
NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry);
return new GetAnomalyDetectorResponse(namedWriteableAwareInput);
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.client.Client;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
Expand All @@ -64,7 +65,7 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest
@Before
public void setup() {
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy);
adClient = new AnomalyDetectionNodeClient(clientSpy, mock(NamedWriteableRegistry.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import java.util.Collection;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.InternalSettingsPlugin;
Expand Down Expand Up @@ -76,6 +79,21 @@ public void testSerializationWithJobAndTask() throws IOException {
assertEquals(response.getDetector(), parsedResponse.getDetector());
}

public void testFromActionResponse() throws IOException {
GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(true, true);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry());

GetAnomalyDetectorResponse reserializedResponse = GetAnomalyDetectorResponse
.fromActionResponse((ActionResponse) response, writableRegistry());
assertEquals(response, reserializedResponse);

ActionResponse invalidActionResponse = new TestActionResponse(input);
assertThrows(Exception.class, () -> GetAnomalyDetectorResponse.fromActionResponse(invalidActionResponse, writableRegistry()));

}

private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean returnJob, boolean returnTask) throws IOException {
GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse(
randomLong(),
Expand All @@ -95,4 +113,17 @@ private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean retu
);
return response;
}

// A test ActionResponse class with an inactive writeTo class. Used to ensure exceptions
// are thrown when parsing implementations of such class.
private class TestActionResponse extends ActionResponse {
public TestActionResponse(StreamInput in) throws IOException {
super(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
return;
}
}
}

0 comments on commit 1507dd4

Please sign in to comment.