Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes and refactoring #38

Merged
merged 3 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.opensearch.action.support.ActionFilter;
import org.opensearch.client.Client;
Expand All @@ -30,14 +31,15 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.relevance.actionfilter.SearchActionFilter;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.SearchConfigurationExtBuilder;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker;
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

Expand All @@ -47,9 +49,13 @@ public class SearchRelevancePlugin extends Plugin implements ActionPlugin, Searc
private KendraHttpClient kendraClient;
private KendraIntelligentRanker kendraIntelligentRanker;

private Map<ResultTransformerType, ResultTransformer> getAllResultTransformers() {
private Collection<ResultTransformer> getAllResultTransformers() {
// Initialize and add other transformers here
return Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, this.kendraIntelligentRanker);
return List.of(this.kendraIntelligentRanker);
}

private Collection<ResultTransformerConfigurationFactory> getResultTransformerConfigurationFactories() {
return List.of(KendraIntelligentRankingConfigurationFactory.INSTANCE);
}

@Override
Expand Down Expand Up @@ -93,8 +99,12 @@ public Collection<Object> createComponents(

@Override
public List<SearchExtSpec<?>> getSearchExts() {
Map<String, ResultTransformerConfigurationFactory> resultTransformerMap = getResultTransformerConfigurationFactories().stream()
.collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i));
return Collections.singletonList(
new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, SearchConfigurationExtBuilder::new, SearchConfigurationExtBuilder::parse));
new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,6 @@
*/
package org.opensearch.search.relevance.actionfilter;

import static org.opensearch.action.search.ShardSearchFailure.readShardSearchFailure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand All @@ -28,7 +17,6 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilter;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -43,27 +31,37 @@
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.relevance.configuration.ConfigurationUtils;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.ConfigurationUtils;
import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration;
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.tasks.Task;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

public class SearchActionFilter implements ActionFilter {
private static final Logger logger = LogManager.getLogger(SearchActionFilter.class);

private final int order;

private final NamedWriteableRegistry namedWriteableRegistry;
private final Map<ResultTransformerType, ResultTransformer> supportedResultTransformers;
private final Map<String, ResultTransformer> resultTransformerMap;
private final OpenSearchClient openSearchClient;

public SearchActionFilter(Map<ResultTransformerType, ResultTransformer> supportedResultTransformers, OpenSearchClient openSearchClient) {
public SearchActionFilter(Collection<ResultTransformer> supportedResultTransformers,
OpenSearchClient openSearchClient) {
order = 10; // TODO: Finalize this value
namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
this.supportedResultTransformers = supportedResultTransformers;
resultTransformerMap = supportedResultTransformers.stream()
.collect(Collectors.toMap(t -> t.getConfigurationFactory().getName(), t -> t));
this.openSearchClient = openSearchClient;
}

Expand Down Expand Up @@ -109,12 +107,12 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
return;
}

List<ResultTransformerConfiguration> resultTransformerConfigurations = getResultTransformerConfigurations(indices[0],
searchRequest);
List<ResultTransformerConfiguration> resultTransformerConfigurations =
getResultTransformerConfigurations(indices[0], searchRequest);

LinkedHashMap<ResultTransformer, ResultTransformerConfiguration> orderedTransformersAndConfigs = new LinkedHashMap<>();
for (ResultTransformerConfiguration config : resultTransformerConfigurations) {
ResultTransformer resultTransformer = supportedResultTransformers.get(config.getType());
ResultTransformer resultTransformer = resultTransformerMap.get(config.getTransformerName());
// TODO: Should transformers make a decision based on the original request or the request they receive in the chain
if (resultTransformer.shouldTransform(searchRequest, config)) {
searchRequest = resultTransformer.preprocessRequest(searchRequest, config);
Expand Down Expand Up @@ -154,17 +152,15 @@ private List<ResultTransformerConfiguration> getResultTransformerConfigurations(
}

// Fetch all index settings for this plugin
String[] settingNames = supportedResultTransformers.values()
String[] settingNames = resultTransformerMap.values()
.stream()
.map(t -> t.getTransformerSettings()
.flatMap(t -> t.getTransformerSettings()
.stream()
.map(Setting::getKey)
.collect(Collectors.toList()))
.flatMap(Collection::stream)
.map(Setting::getKey))
.toArray(String[]::new);

configs = ConfigurationUtils.getResultTransformersFromIndexConfiguration(
openSearchClient.getIndexSettings(indexName, settingNames));
openSearchClient.getIndexSettings(indexName, settingNames), resultTransformerMap);

return configs;
}
Expand Down Expand Up @@ -194,21 +190,26 @@ public void onResponse(final Response response) {
final SearchResponse searchResponse = (SearchResponse) response;
final long totalHits = searchResponse.getHits().getTotalHits().value;
if (totalHits == 0) {
logger.info("TotalHits = 0. Returning search response without re-ranking.");
logger.info("TotalHits = 0. Returning search response without transforming.");
listener.onResponse(response);
return;
}

logger.debug("Starting re-ranking for search response: {}", searchResponse);
try {
final BytesStreamOutput out = new BytesStreamOutput();
searchResponse.writeTo(out);

final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry);

// Clone search hits (by serializing + deserializing) before transforming
final BytesStreamOutput out = new BytesStreamOutput();
searchResponse.getHits().writeTo(out);
final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(),
namedWriteableRegistry);
SearchHits hits = new SearchHits(in);

for (Map.Entry<ResultTransformer, ResultTransformerConfiguration> entry : orderedTransformersAndConfigs.entrySet()) {
long startTime = System.nanoTime();
hits = entry.getKey().transform(hits, searchRequest, entry.getValue());
long timeTookMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
logger.info(entry.getValue().getTransformerName() + ": took " + timeTookMillis + " ms");
}

List<SearchHit> searchHitsList = Arrays.asList(hits.getHits());
Expand All @@ -230,52 +231,22 @@ public void onResponse(final Response response) {
}
}

// TODO: How to handle SearchHits.TotalHits when transformer modifies the hit count
hits = new SearchHits(
searchHitsList.toArray(new SearchHit[0]),
hits.getTotalHits(),
hits.getMaxScore());

final InternalAggregations aggregations =
in.readBoolean() ? InternalAggregations.readFrom(in) : null;
final Suggest suggest = in.readBoolean() ? new Suggest(in) : null;
final boolean timedOut = in.readBoolean();
final Boolean terminatedEarly = in.readOptionalBoolean();
final SearchProfileShardResults profileResults = in.readOptionalWriteable(
SearchProfileShardResults::new);
final int numReducePhases = in.readVInt();

final SearchResponseSections internalResponse = new InternalSearchResponse(hits,
aggregations, suggest,
profileResults, timedOut, terminatedEarly, numReducePhases);

final int totalShards = in.readVInt();
final int successfulShards = in.readVInt();
final int shardSearchFailureSize = in.readVInt();
final ShardSearchFailure[] shardFailures;
if (shardSearchFailureSize == 0) {
shardFailures = ShardSearchFailure.EMPTY_ARRAY;
} else {
shardFailures = new ShardSearchFailure[shardSearchFailureSize];
for (int i = 0; i < shardFailures.length; i++) {
shardFailures[i] = readShardSearchFailure(in);
}
}

final SearchResponse.Clusters clusters = new SearchResponse.Clusters(in.readVInt(),
in.readVInt(), in.readVInt());
final String scrollId = in.readOptionalString();
final int skippedShards = in.readVInt();

final long tookInMillis = (System.nanoTime() - startTime) / 1000000;
final SearchResponse newResponse = new SearchResponse(internalResponse, scrollId,
totalShards, successfulShards,
skippedShards, tookInMillis, shardFailures, clusters);
(InternalAggregations) searchResponse.getAggregations(), searchResponse.getSuggest(),
new SearchProfileShardResults(searchResponse.getProfileResults()), searchResponse.isTimedOut(),
searchResponse.isTerminatedEarly(), searchResponse.getNumReducePhases());

final long tookInMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
final SearchResponse newResponse = new SearchResponse(internalResponse, searchResponse.getScrollId(),
searchResponse.getTotalShards(), searchResponse.getSuccessfulShards(),
searchResponse.getSkippedShards(), tookInMillis, searchResponse.getShardFailures(),
searchResponse.getClusters());
listener.onResponse((Response) newResponse);

// TODO: Change this to a metric
logger.info("Result transformer operations overhead time: {}ms",
tookInMillis - searchResponse.getTook().getMillis());
macohen marked this conversation as resolved.
Show resolved Hide resolved
} catch (final Exception e) {
logger.error("Result transformer operations failed.", e);
throw new OpenSearchException("Result transformer operations failed.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
*/
package org.opensearch.search.relevance.configuration;

import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.relevance.transformer.ResultTransformer;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.relevance.transformer.ResultTransformerType;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration;

import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX;

public class ConfigurationUtils {

Expand All @@ -27,16 +27,16 @@ public class ConfigurationUtils {
* @param settings all index settings configured for this plugin
* @return ordered and validated list of result transformers, empty list if not specified
*/
public static List<ResultTransformerConfiguration> getResultTransformersFromIndexConfiguration(
Settings settings) {
public static List<ResultTransformerConfiguration> getResultTransformersFromIndexConfiguration(Settings settings,
Map<String, ResultTransformer> resultTransformerMap) {
List<ResultTransformerConfiguration> indexLevelConfigs = new ArrayList<>();

if (settings != null) {
if (settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX) != null) {
for (Map.Entry<String, Settings> resultTransformer : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) {
ResultTransformerType resultTransformerType = ResultTransformerType.fromString(resultTransformer.getKey());
if (ResultTransformerType.KENDRA_INTELLIGENT_RANKING.equals(resultTransformerType)) {
indexLevelConfigs.add(new KendraIntelligentRankingConfiguration(resultTransformer.getValue()));
for (Map.Entry<String, Settings> transformerSettings : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) {
if (resultTransformerMap.containsKey(transformerSettings.getKey())) {
ResultTransformer transformer = resultTransformerMap.get(transformerSettings.getKey());
indexLevelConfigs.add(transformer.getConfigurationFactory().configure(transformerSettings.getValue()));
}
}
}
Expand Down Expand Up @@ -86,7 +86,7 @@ public static List<ResultTransformerConfiguration> reorderAndValidateConfigs(
for (int i = 0; i < configs.size(); ++i) {
if (configs.get(i).getOrder() != (i + 1)) {
throw new IllegalArgumentException("Expected order [" + (i + 1) + "] for transformer [" +
configs.get(i).getType() + "], but found [" + configs.get(i).getOrder() + "]");
configs.get(i).getTransformerName() + "], but found [" + configs.get(i).getOrder() + "]");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
*/
package org.opensearch.search.relevance.configuration;

import org.opensearch.search.relevance.transformer.ResultTransformerType;

public abstract class ResultTransformerConfiguration extends TransformerConfiguration {

public abstract ResultTransformerType getType();

public abstract String getTransformerName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.configuration;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentParser;

import java.io.IOException;

public interface ResultTransformerConfigurationFactory {
String getName();

/**
* Build configuration based on index settings
* @param indexSettings a set of index settings under a group scoped based on this result transformer's name.
* @return a transformer configuration based on the passed settings.
*/
ResultTransformerConfiguration configure(Settings indexSettings);

/**
* Build configuration from serialized XContent, e.g. as part of a serialized {@link SearchConfigurationExtBuilder}.
* @param parser an XContentParser pointing to a node serialized from a {@link ResultTransformerConfiguration} of
* this type.
* @return a transformer configuration based on the parameters specified in the XContent.
*/
ResultTransformerConfiguration configure(XContentParser parser) throws IOException;

/**
* Build configuration from a serialized stream.
* @param streamInput a {@link org.opensearch.common.io.stream.Writeable} serialized representation of transformer
* configuration.
* @return configuration the deserialized transformer configuration.
*/
ResultTransformerConfiguration configure(StreamInput streamInput) throws IOException;

}