Skip to content

Commit

Permalink
Bug fixes and refactoring (#38)
Browse files Browse the repository at this point in the history
* Bug fixes and refactoring

1. Shallow clone SearchResponse in SearchActionFilter. The only
   properties that we're currently changing are SearchHit and
   timeTookMillis. This is also less brittle as we don't try to
   implement deserialization ourselves.
2. KendraIntelligentRanker doesn't transform if the required client
   settings (endpoint + execution plan ID) are missing.
3. A lot of refactoring to prepare for other transformers. Where we
   previously had logic of the form `if kendra_intelligent_ranking then
   do Kendra intelligent ranking`, now the APIs support arbitrary named
   transformers. This required updates to tests and configuration code.
4. Added tests and fixes for SearchConfigurationExtBuilder.

Signed-off-by: Michael Froh <froh@amazon.com>

* Address PR comments from myself

Looking at the PR on GitHub, I spotted a few mistakes and possible
improvements.

Signed-off-by: Michael Froh <froh@amazon.com>

* Address comments by noCharger

Signed-off-by: Michael Froh <froh@amazon.com>

Signed-off-by: Michael Froh <froh@amazon.com>
  • Loading branch information
msfroh committed Dec 19, 2022
1 parent d9b44e7 commit d8ba75b
Show file tree
Hide file tree
Showing 17 changed files with 611 additions and 353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.opensearch.action.support.ActionFilter;
import org.opensearch.client.Client;
Expand All @@ -30,14 +31,15 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.relevance.actionfilter.SearchActionFilter;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.SearchConfigurationExtBuilder;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker;
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

Expand All @@ -47,9 +49,13 @@ public class SearchRelevancePlugin extends Plugin implements ActionPlugin, Searc
private KendraHttpClient kendraClient;
private KendraIntelligentRanker kendraIntelligentRanker;

private Map<ResultTransformerType, ResultTransformer> getAllResultTransformers() {
private Collection<ResultTransformer> getAllResultTransformers() {
// Initialize and add other transformers here
return Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, this.kendraIntelligentRanker);
return List.of(this.kendraIntelligentRanker);
}

private Collection<ResultTransformerConfigurationFactory> getResultTransformerConfigurationFactories() {
return List.of(KendraIntelligentRankingConfigurationFactory.INSTANCE);
}

@Override
Expand Down Expand Up @@ -93,8 +99,12 @@ public Collection<Object> createComponents(

@Override
public List<SearchExtSpec<?>> getSearchExts() {
Map<String, ResultTransformerConfigurationFactory> resultTransformerMap = getResultTransformerConfigurationFactories().stream()
.collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i));
return Collections.singletonList(
new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, SearchConfigurationExtBuilder::new, SearchConfigurationExtBuilder::parse));
new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,6 @@
*/
package org.opensearch.search.relevance.actionfilter;

import static org.opensearch.action.search.ShardSearchFailure.readShardSearchFailure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand All @@ -28,7 +17,6 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilter;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -43,27 +31,37 @@
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.relevance.configuration.ConfigurationUtils;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.ConfigurationUtils;
import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration;
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.tasks.Task;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

public class SearchActionFilter implements ActionFilter {
private static final Logger logger = LogManager.getLogger(SearchActionFilter.class);

private final int order;

private final NamedWriteableRegistry namedWriteableRegistry;
private final Map<ResultTransformerType, ResultTransformer> supportedResultTransformers;
private final Map<String, ResultTransformer> resultTransformerMap;
private final OpenSearchClient openSearchClient;

public SearchActionFilter(Map<ResultTransformerType, ResultTransformer> supportedResultTransformers, OpenSearchClient openSearchClient) {
public SearchActionFilter(Collection<ResultTransformer> supportedResultTransformers,
OpenSearchClient openSearchClient) {
order = 10; // TODO: Finalize this value
namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
this.supportedResultTransformers = supportedResultTransformers;
resultTransformerMap = supportedResultTransformers.stream()
.collect(Collectors.toMap(t -> t.getConfigurationFactory().getName(), t -> t));
this.openSearchClient = openSearchClient;
}

Expand Down Expand Up @@ -109,12 +107,12 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
return;
}

List<ResultTransformerConfiguration> resultTransformerConfigurations = getResultTransformerConfigurations(indices[0],
searchRequest);
List<ResultTransformerConfiguration> resultTransformerConfigurations =
getResultTransformerConfigurations(indices[0], searchRequest);

LinkedHashMap<ResultTransformer, ResultTransformerConfiguration> orderedTransformersAndConfigs = new LinkedHashMap<>();
for (ResultTransformerConfiguration config : resultTransformerConfigurations) {
ResultTransformer resultTransformer = supportedResultTransformers.get(config.getType());
ResultTransformer resultTransformer = resultTransformerMap.get(config.getTransformerName());
// TODO: Should transformers make a decision based on the original request or on the request they receive in the chain?
if (resultTransformer.shouldTransform(searchRequest, config)) {
searchRequest = resultTransformer.preprocessRequest(searchRequest, config);
Expand Down Expand Up @@ -154,17 +152,15 @@ private List<ResultTransformerConfiguration> getResultTransformerConfigurations(
}

// Fetch all index settings for this plugin
String[] settingNames = supportedResultTransformers.values()
String[] settingNames = resultTransformerMap.values()
.stream()
.map(t -> t.getTransformerSettings()
.flatMap(t -> t.getTransformerSettings()
.stream()
.map(Setting::getKey)
.collect(Collectors.toList()))
.flatMap(Collection::stream)
.map(Setting::getKey))
.toArray(String[]::new);

configs = ConfigurationUtils.getResultTransformersFromIndexConfiguration(
openSearchClient.getIndexSettings(indexName, settingNames));
openSearchClient.getIndexSettings(indexName, settingNames), resultTransformerMap);

return configs;
}
Expand Down Expand Up @@ -194,21 +190,26 @@ public void onResponse(final Response response) {
final SearchResponse searchResponse = (SearchResponse) response;
final long totalHits = searchResponse.getHits().getTotalHits().value;
if (totalHits == 0) {
logger.info("TotalHits = 0. Returning search response without re-ranking.");
logger.info("TotalHits = 0. Returning search response without transforming.");
listener.onResponse(response);
return;
}

logger.debug("Starting re-ranking for search response: {}", searchResponse);
try {
final BytesStreamOutput out = new BytesStreamOutput();
searchResponse.writeTo(out);

final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry);

// Clone search hits (by serializing + deserializing) before transforming
final BytesStreamOutput out = new BytesStreamOutput();
searchResponse.getHits().writeTo(out);
final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(),
namedWriteableRegistry);
SearchHits hits = new SearchHits(in);

for (Map.Entry<ResultTransformer, ResultTransformerConfiguration> entry : orderedTransformersAndConfigs.entrySet()) {
long startTime = System.nanoTime();
hits = entry.getKey().transform(hits, searchRequest, entry.getValue());
long timeTookMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
logger.info(entry.getValue().getTransformerName() + ": took " + timeTookMillis + " ms");
}

List<SearchHit> searchHitsList = Arrays.asList(hits.getHits());
Expand All @@ -230,52 +231,22 @@ public void onResponse(final Response response) {
}
}

// TODO: How to handle SearchHits.TotalHits when transformer modifies the hit count
hits = new SearchHits(
searchHitsList.toArray(new SearchHit[0]),
hits.getTotalHits(),
hits.getMaxScore());

final InternalAggregations aggregations =
in.readBoolean() ? InternalAggregations.readFrom(in) : null;
final Suggest suggest = in.readBoolean() ? new Suggest(in) : null;
final boolean timedOut = in.readBoolean();
final Boolean terminatedEarly = in.readOptionalBoolean();
final SearchProfileShardResults profileResults = in.readOptionalWriteable(
SearchProfileShardResults::new);
final int numReducePhases = in.readVInt();

final SearchResponseSections internalResponse = new InternalSearchResponse(hits,
aggregations, suggest,
profileResults, timedOut, terminatedEarly, numReducePhases);

final int totalShards = in.readVInt();
final int successfulShards = in.readVInt();
final int shardSearchFailureSize = in.readVInt();
final ShardSearchFailure[] shardFailures;
if (shardSearchFailureSize == 0) {
shardFailures = ShardSearchFailure.EMPTY_ARRAY;
} else {
shardFailures = new ShardSearchFailure[shardSearchFailureSize];
for (int i = 0; i < shardFailures.length; i++) {
shardFailures[i] = readShardSearchFailure(in);
}
}

final SearchResponse.Clusters clusters = new SearchResponse.Clusters(in.readVInt(),
in.readVInt(), in.readVInt());
final String scrollId = in.readOptionalString();
final int skippedShards = in.readVInt();

final long tookInMillis = (System.nanoTime() - startTime) / 1000000;
final SearchResponse newResponse = new SearchResponse(internalResponse, scrollId,
totalShards, successfulShards,
skippedShards, tookInMillis, shardFailures, clusters);
(InternalAggregations) searchResponse.getAggregations(), searchResponse.getSuggest(),
new SearchProfileShardResults(searchResponse.getProfileResults()), searchResponse.isTimedOut(),
searchResponse.isTerminatedEarly(), searchResponse.getNumReducePhases());

final long tookInMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
final SearchResponse newResponse = new SearchResponse(internalResponse, searchResponse.getScrollId(),
searchResponse.getTotalShards(), searchResponse.getSuccessfulShards(),
searchResponse.getSkippedShards(), tookInMillis, searchResponse.getShardFailures(),
searchResponse.getClusters());
listener.onResponse((Response) newResponse);

// TODO: Change this to a metric
logger.info("Result transformer operations overhead time: {}ms",
tookInMillis - searchResponse.getTook().getMillis());
} catch (final Exception e) {
logger.error("Result transformer operations failed.", e);
throw new OpenSearchException("Result transformer operations failed.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
*/
package org.opensearch.search.relevance.configuration;

import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.relevance.transformer.ResultTransformer;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration;

import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX;

public class ConfigurationUtils {

Expand All @@ -27,16 +27,16 @@ public class ConfigurationUtils {
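* @param settings all index settings configured for this plugin
* @param resultTransformerMap supported result transformers keyed by transformer name; settings groups whose
* name has no entry in this map are ignored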
* @return ordered and validated list of result transformers, empty list if not specified
*/
public static List<ResultTransformerConfiguration> getResultTransformersFromIndexConfiguration(
Settings settings) {
public static List<ResultTransformerConfiguration> getResultTransformersFromIndexConfiguration(Settings settings,
Map<String, ResultTransformer> resultTransformerMap) {
List<ResultTransformerConfiguration> indexLevelConfigs = new ArrayList<>();

if (settings != null) {
if (settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX) != null) {
for (Map.Entry<String, Settings> resultTransformer : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) {
ResultTransformerType resultTransformerType = ResultTransformerType.fromString(resultTransformer.getKey());
if (ResultTransformerType.KENDRA_INTELLIGENT_RANKING.equals(resultTransformerType)) {
indexLevelConfigs.add(new KendraIntelligentRankingConfiguration(resultTransformer.getValue()));
for (Map.Entry<String, Settings> transformerSettings : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) {
if (resultTransformerMap.containsKey(transformerSettings.getKey())) {
ResultTransformer transformer = resultTransformerMap.get(transformerSettings.getKey());
indexLevelConfigs.add(transformer.getConfigurationFactory().configure(transformerSettings.getValue()));
}
}
}
Expand Down Expand Up @@ -86,7 +86,7 @@ public static List<ResultTransformerConfiguration> reorderAndValidateConfigs(
for (int i = 0; i < configs.size(); ++i) {
if (configs.get(i).getOrder() != (i + 1)) {
throw new IllegalArgumentException("Expected order [" + (i + 1) + "] for transformer [" +
configs.get(i).getType() + "], but found [" + configs.get(i).getOrder() + "]");
configs.get(i).getTransformerName() + "], but found [" + configs.get(i).getOrder() + "]");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
*/
package org.opensearch.search.relevance.configuration;

import org.opensearch.search.relevance.transformer.ResultTransformerType;

public abstract class ResultTransformerConfiguration extends TransformerConfiguration {

public abstract ResultTransformerType getType();

public abstract String getTransformerName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.configuration;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentParser;

import java.io.IOException;

/**
* Factory that builds {@link ResultTransformerConfiguration} instances for a single named result transformer.
* A configuration can be built from three sources: index settings, serialized XContent, or a serialized stream.
*/
public interface ResultTransformerConfigurationFactory {
/**
* Returns the name of the result transformer that this factory configures.
* NOTE(review): callers appear to use this name as a map key and to match it against index-settings group
* names, so it should be unique per factory — confirm.
* @return the result transformer's name.
*/
String getName();

/**
* Build configuration based on index settings
* @param indexSettings a set of index settings under a group scoped based on this result transformer's name.
* @return a transformer configuration based on the passed settings.
*/
ResultTransformerConfiguration configure(Settings indexSettings);

/**
* Build configuration from serialized XContent, e.g. as part of a serialized {@link SearchConfigurationExtBuilder}.
* @param parser an XContentParser pointing to a node serialized from a {@link ResultTransformerConfiguration} of
* this type.
* @return a transformer configuration based on the parameters specified in the XContent.
* @throws IOException declared for failures while reading from the parser; the exact conditions are
* implementation-specific.
*/
ResultTransformerConfiguration configure(XContentParser parser) throws IOException;

/**
* Build configuration from a serialized stream.
* @param streamInput a {@link org.opensearch.common.io.stream.Writeable} serialized representation of transformer
* configuration.
* @return the deserialized transformer configuration.
* @throws IOException declared for failures while reading from the stream; the exact conditions are
* implementation-specific.
*/
ResultTransformerConfiguration configure(StreamInput streamInput) throws IOException;

}

0 comments on commit d8ba75b

Please sign in to comment.