Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add validations for Personalize input and configurations #160

Merged
merged 1 commit into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.ValidationUtil;

import java.util.Map;
import java.util.concurrent.TimeUnit;
Expand All @@ -41,7 +42,7 @@ public class PersonalizeRankingResponseProcessor extends AbstractProcessor imple

private static final Logger logger = LogManager.getLogger(PersonalizeRankingResponseProcessor.class);

public static final String TYPE = "personalize_ranking";
public static final String TYPE = "personalized_search_ranking";
private final String tag;
private final String description;
private final PersonalizeClient personalizeClient;
Expand Down Expand Up @@ -163,6 +164,7 @@ public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<

PersonalizeIntelligentRankerConfiguration rankerConfig =
new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, awsRegion, weight);
ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, tag);
AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(personalizeClientSettings, iamRoleArn, awsRegion);
PersonalizeClient personalizeClient = clientBuilder.apply(credentialsProvider, awsRegion);
return new PersonalizeRankingResponseProcessor(tag, description, ignoreFailure, rankerConfig, personalizeClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client;

import com.amazonaws.AmazonServiceException;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntime;
import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntimeClientBuilder;
Expand All @@ -24,17 +25,21 @@
*/
public class PersonalizeClient implements Closeable {
private final AmazonPersonalizeRuntime personalizeRuntime;
private static final String USER_AGENT_PREFIX = "PersonalizeOpenSearchPlugin";

/**
* Constructor for Amazon Personalize client
* @param credentialsProvider Credentials to be used for accessing Amazon Personalize
* @param awsRegion AWS region where Amazon Personalize campaign is hosted
*/
public PersonalizeClient(AWSCredentialsProvider credentialsProvider, String awsRegion) {
ClientConfiguration clientConfiguration = new ClientConfiguration()
.withUserAgentPrefix(USER_AGENT_PREFIX);
personalizeRuntime = AccessController.doPrivileged(
(PrivilegedAction<AmazonPersonalizeRuntime>) () -> AmazonPersonalizeRuntimeClientBuilder.standard()
.withCredentials(credentialsProvider)
.withRegion(awsRegion)
.withClientConfiguration(clientConfiguration)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ public final class PersonalizeClientSettings {
/**
* The access key (ie login id) for connecting to Personalize.
*/
public static final Setting<SecureString> ACCESS_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.access_key", null);
public static final Setting<SecureString> ACCESS_KEY_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.access_key", null);

/**
* The secret key (ie password) for connecting to Personalize.
*/
public static final Setting<SecureString> SECRET_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.secret_key", null);
public static final Setting<SecureString> SECRET_KEY_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.secret_key", null);

/**
* The session token for connecting to Personalize.
*/
public static final Setting<SecureString> SESSION_TOKEN_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.session_token", null);
public static final Setting<SecureString> SESSION_TOKEN_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.session_token", null);

private final AWSCredentials credentials;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,4 @@ public interface PersonalizedRanker {
* @return Re ranked search hits
*/
SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters);

/**
* Validate Personalize configuration for calling Personalize service
* @param requestParameters Request parameters for Personalize present in search request
* @return True if valid configuration present else false.
*/
boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters);

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import com.amazonaws.services.personalizeruntime.model.PredictedItem;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TotalHits;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
Expand All @@ -23,9 +24,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.LinkedList;
import java.util.Map;
import java.util.stream.Collectors;

Expand All @@ -51,15 +51,17 @@ public AmazonPersonalizedRankerImpl(PersonalizeIntelligentRankerConfiguration co
@Override
public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters) {
try {
if (!isValidPersonalizeConfigPresent(requestParameters)) {
throw new IllegalArgumentException("Required configurations missing from Personalize " +
"response processor configuration or search request parameters");
}
validatePersonalizeRequestParams(requestParameters);
List<SearchHit> originalHits = Arrays.asList(hits.getHits());
// Do not make Personalize call if weight is zero which implies Personalization is turned off.
if (rankerConfig.getWeight() == 0) {
logger.info("Not applying Personalized ranking. Given value for weight configuration: {}", rankerConfig.getWeight());
return hits;
}
String itemIdfield = rankerConfig.getItemIdField();
List<String> documentIdsToRank;
// If item field is not specified in the configuration then use default _id field.
if (!itemIdfield.isEmpty()) {
if (itemIdfield != null && !itemIdfield.isBlank()) {
documentIdsToRank = originalHits.stream()
.filter(h -> h.getSourceAsMap().get(itemIdfield) != null)
.map(h -> h.getSourceAsMap().get(itemIdfield).toString())
Expand All @@ -70,13 +72,17 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
.map(h -> h.getId())
.collect(Collectors.toList());
}
if (documentIdsToRank.size() == 0) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "item_id_field",
"no item ids found to apply Personalized reranking. Please check configured value for item_id_field");
}
logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray()));
String userId = requestParameters.getUserId();
Map<String, String> context = requestParameters.getContext() != null ?
requestParameters.getContext().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> isValidPersonalizeContext(e)))
.collect(Collectors.toMap(Map.Entry::getKey, e -> (String)e.getValue()))
: null;
logger.info("User ID from request parameters. User ID: {}", userId);
logger.info("User ID from personalize request parameters - User ID: {}", userId);
if (context != null && !context.isEmpty()) {
logger.info("Personalize context provided in the search request");
}
Expand All @@ -88,109 +94,67 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
.withUserId(userId);
GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest);

List<PredictedItem> personalizeRrankingResult = result.getPersonalizedRanking();
Map<String, Float> idToPersonalizeRankingScoreMap = new HashMap<>();
Map<String, Float> idToOpenSearchScoreMap = new HashMap<>();
Map<String, SearchHit> itemIdToSearchHitMap = new HashMap<>();
// Build a map with key as item id and value as personalize ranking score
for (PredictedItem item : personalizeRrankingResult) {
idToPersonalizeRankingScoreMap.put(item.getItemId(), item.getScore().floatValue());
}

// Build a map with key as item id and value as open search scores and another map
// with key as item id and value as corresponding search hit
for (SearchHit hit : originalHits) {
if (!itemIdfield.isEmpty()){
idToOpenSearchScoreMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit.getScore());
itemIdToSearchHitMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit);
}
else{
idToOpenSearchScoreMap.put(hit.getId(), hit.getScore());
itemIdToSearchHitMap.put(hit.getId(), hit);
}
}


float weight = (float) rankerConfig.getWeight();
SearchHits newHits = combineScores(idToPersonalizeRankingScoreMap, idToOpenSearchScoreMap,
itemIdToSearchHitMap, hits.getTotalHits(), weight);
return newHits;
SearchHits personalizedHits = combineScores(hits, result);
return personalizedHits;
} catch (Exception ex) {
logger.error("Failed to re rank with Personalize. Returning original search results without Personalize re ranking.", ex);
return hits;
logger.error("Failed to re rank with Personalize.", ex);
throw ex;
}
}

// Combine the OpenSearch hits with the Personalize campaign response
public SearchHits combineScores(Map<String, Float> idToPersonalizeRankingScoreMap,
Map<String, Float> idToOpenSearchScoreMap,
Map<String, SearchHit> itemIdToSearchHitMap,
TotalHits totalHits, float weight) {
//Update open search score based on the personalize campaign response for each item id
List<String> openSearchItemId = new ArrayList<String>(idToOpenSearchScoreMap.keySet());
for (String itemId : openSearchItemId) {
if(idToPersonalizeRankingScoreMap.containsKey(itemId)){
float personalizedScore = idToPersonalizeRankingScoreMap.get(itemId);
float openSearchScore = idToOpenSearchScoreMap.get(itemId);
float combinedScore = (float) (weight / Math.log(openSearchScore + 1)
+ (1 - weight) / Math.log(personalizedScore + 1));
idToOpenSearchScoreMap.put(itemId, combinedScore);
/**
 * Blend the OpenSearch ordering of the hits with the ordering returned by Personalize and
 * return the hits sorted by the blended score, highest first.
 *
 * <p>Each hit is scored as
 * {@code (1 - weight) / log2(openSearchRank + 1) + weight / log2(personalizedRank + 1)},
 * where both ranks are 1-based positions. A hit that cannot be matched to an entry of the
 * Personalize response (its item id is absent from the response, or its source has no value for
 * the configured item id field) gets no personalized contribution. Previously such a hit either
 * raised a NullPointerException or was assigned an infinite score through a division by
 * {@code log2(1) == 0}.
 *
 * <p>The score of every hit in {@code originalHits} is overwritten in place.
 *
 * @param originalHits hits in the order returned by OpenSearch
 * @param personalizedRankingResult ranking response returned by Personalize
 * @return search hits re-scored and sorted by descending combined score, carrying the total hit
 *         count of {@code originalHits}
 */
private SearchHits combineScores(SearchHits originalHits, GetPersonalizedRankingResult personalizedRankingResult) {
    // Item ids in Personalize order; the list position + 1 is the personalized rank.
    List<String> personalizedItemIds = new ArrayList<>();
    for (PredictedItem item : personalizedRankingResult.getPersonalizedRanking()) {
        personalizedItemIds.add(item.getItemId());
    }

    // Loop-invariant configuration, read once.
    final double weight = rankerConfig.getWeight();
    final String itemIdField = rankerConfig.getItemIdField();
    final boolean useItemIdField = itemIdField != null && !itemIdField.isBlank();
    final double logOfTwo = Math.log(2);

    final int hitCount = originalHits.getHits().length;
    List<SearchHit> rerankedHits = new ArrayList<>(hitCount);
    float maxScore = 0f;
    for (int i = 0; i < hitCount; i++) {
        SearchHit hit = originalHits.getAt(i);
        String openSearchItemId;
        if (useItemIdField) {
            // The configured field may be absent from a hit's source; such a hit is unranked.
            Object fieldValue = hit.getSourceAsMap().get(itemIdField);
            openSearchItemId = fieldValue == null ? null : fieldValue.toString();
        } else {
            openSearchItemId = hit.getId();
        }

        int openSearchRank = i + 1;
        // indexOf returns -1 for an unknown id, which leaves the personalized rank at 0.
        int personalizedRank = openSearchItemId == null ? 0 : personalizedItemIds.indexOf(openSearchItemId) + 1;

        double combined = (1 - weight) / (Math.log(openSearchRank + 1) / logOfTwo);
        if (personalizedRank > 0) {
            // Only ranked items get the personalized term; rank 0 would divide by log2(1) == 0.
            combined += weight / (Math.log(personalizedRank + 1) / logOfTwo);
        }

        float combinedScore = (float) combined;
        maxScore = Math.max(maxScore, combinedScore);
        hit.score(combinedScore);
        rerankedHits.add(hit);
    }

    rerankedHits.sort(Comparator.comparing(SearchHit::getScore).reversed());
    return new SearchHits(rerankedHits.toArray(new SearchHit[0]), originalHits.getTotalHits(), maxScore);
}


/**
* Validate Personalize configuration for calling Personalize service
* @param requestParameters Request parameters for Personalize present in search request
* @return True if valid configuration present else false.
*/
public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters) {
boolean isValidPersonalizeConfig = true;

if (requestParameters == null || requestParameters.getUserId().isEmpty()) {
isValidPersonalizeConfig = false;
logger.error("Required Personalize parameters are not provided in the search request");
private void validatePersonalizeRequestParams(PersonalizeRequestParameters requestParameters) {
if (requestParameters == null || requestParameters.getUserId() == null || requestParameters.getUserId().isBlank()) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "user_id",
"required Personalize request parameter is missing");
}

if (rankerConfig == null || rankerConfig.getPersonalizeCampaign().isEmpty() ||
rankerConfig.getWeight() < 0.0 || rankerConfig.getWeight() > 1.0) {
isValidPersonalizeConfig = false;
logger.error("Required Personalized ranker configuration is missing");
if (requestParameters.getContext() != null) {
try {
requestParameters.getContext().entrySet().stream().forEach(e -> isValidPersonalizeContext(e));
} catch (IllegalArgumentException iae) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "context", iae.getMessage());
}
}
return isValidPersonalizeConfig;
}

private String isValidPersonalizeContext(Map.Entry<String, Object> contextEntry) throws IllegalArgumentException {
if (contextEntry.getValue() instanceof String) {
return (String) contextEntry.getValue();
} else {
throw new IllegalArgumentException("Personalize context value is not of type String. " +
"Invalid context value: " + contextEntry.getValue());
/**
 * Check that the value of a single Personalize context entry is a String.
 * @param contextEntry context entry taken from the Personalize request parameters
 * @throws IllegalArgumentException if the entry's value is not a String
 */
private void isValidPersonalizeContext(Map.Entry<String, Object> contextEntry) throws IllegalArgumentException {
    final Object contextValue = contextEntry.getValue();
    if (contextValue instanceof String) {
        return;
    }
    throw new IllegalArgumentException("Personalize context value is not of type String. Invalid context value: " + contextValue);
}
}
Loading
Loading