Skip to content

Commit

Permalink
Support contextual metadata to use when getting personalized reranking (
Browse files Browse the repository at this point in the history
#144) (#145)

This change supports using contextual metadata for personalized reranking if provided as a search request parameter
Change ensures appropriate exception is thrown when context parameter values are not of type String.
Added unit tests to improve test coverage.
Change validation and exception handling for personalize context

---------

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>
(cherry picked from commit 55a702c)

Co-authored-by: kulket <130191298+kulket@users.noreply.github.com>
  • Loading branch information
opensearch-trigger-bot[bot] and kulket committed May 31, 2023
1 parent 2621536 commit 7285598
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 34 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,45 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.USER_ID_PARAMETER;

public class PersonalizeRequestParameters implements Writeable, ToXContentObject {

static final String PERSONALIZE_REQUEST_PARAMETERS = "personalize_request_parameters";
private static final String USER_ID_PARAMETER = "user_id";
private static final String CONTEXT_PARAMETER = "context";

private static final ObjectParser<PersonalizeRequestParameters, Void> PARSER;
private static final ParseField USER_ID = new ParseField(USER_ID_PARAMETER);
private static final ParseField CONTEXT = new ParseField(CONTEXT_PARAMETER);

static {
PARSER = new ObjectParser<>(PERSONALIZE_REQUEST_PARAMETERS, PersonalizeRequestParameters::new);
PARSER.declareString(PersonalizeRequestParameters::setUserId, USER_ID);
PARSER.declareObject(PersonalizeRequestParameters::setContext,(XContentParser p, Void c) -> {
try {
return p.map();
} catch (IOException e) {
throw new IllegalArgumentException("Error parsing Personalize context from request parameters", e);
}
}, CONTEXT);
}

private String userId;

private Map<String, Object> context;

public PersonalizeRequestParameters() {}

public PersonalizeRequestParameters(String userId) {
public PersonalizeRequestParameters(String userId, Map<String, Object> context) {
this.userId = userId;
this.context = context;
}

public PersonalizeRequestParameters(StreamInput input) throws IOException {
this.userId = input.readString();
this.context = input.readMap();
}

public String getUserId() {
Expand All @@ -52,14 +66,24 @@ public void setUserId(String userId) {
this.userId = userId;
}

public Map<String, Object> getContext() {
return context;
}

public void setContext(Map<String, Object> context) {
this.context = context;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.userId);
out.writeMap(this.context);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.field(USER_ID.getPreferredName(), this.userId);
builder.field(USER_ID.getPreferredName(), this.userId);
return builder.field(CONTEXT.getPreferredName(), this.context);
}

public static PersonalizeRequestParameters parse(XContentParser parser) throws IOException {
Expand All @@ -75,11 +99,12 @@ public boolean equals(Object o) {
PersonalizeRequestParameters config = (PersonalizeRequestParameters) o;

if (!userId.equals(config.userId)) return false;
return userId.equals(config.userId);
if (context.size() != config.getContext().size()) return false;
return userId.equals(config.userId) && context.equals(config.getContext());
}

@Override
public int hashCode() {
return Objects.hash(userId);
return Objects.hash(userId, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import java.io.IOException;
import java.util.Objects;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters.PERSONALIZE_REQUEST_PARAMETERS;

public class PersonalizeRequestParametersExtBuilder extends SearchExtBuilder {
private static final Logger logger = LogManager.getLogger(PersonalizeRequestParametersExtBuilder.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@

import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest;
import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult;
import com.amazonaws.services.personalizeruntime.model.PredictedItem;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -53,7 +52,7 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
List<SearchHit> originalHits = Arrays.asList(hits.getHits());
String itemIdfield = rankerConfig.getItemIdField();
List<String> documentIdsToRank;
// If item field is not specified in the configruation then use default _id field.
// If item field is not specified in the configuration then use default _id field.
if (!itemIdfield.isEmpty()) {
documentIdsToRank = originalHits.stream()
.filter(h -> h.getSourceAsMap().get(itemIdfield) != null)
Expand All @@ -66,12 +65,20 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
.collect(Collectors.toList());
}
logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray()));
// TODO: Parse context from request parameters
String userId = requestParameters.getUserId();
Map<String, String> context = requestParameters.getContext() != null ?
requestParameters.getContext().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> isValidPersonalizeContext(e)))
: null;
logger.info("User ID from request parameters. User ID: {}", userId);
if (context != null && !context.isEmpty()) {
logger.info("Personalize context provided in the search request");
}

GetPersonalizedRankingRequest personalizeRequest = new GetPersonalizedRankingRequest()
.withCampaignArn(rankerConfig.getPersonalizeCampaign())
.withInputList(documentIdsToRank)
.withContext(context)
.withUserId(userId);
GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest);

Expand Down Expand Up @@ -103,4 +110,13 @@ public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requ
}
return isValidPersonalizeConfig;
}

private String isValidPersonalizeContext(Map.Entry<String, Object> contextEntry) throws IllegalArgumentException {
if (contextEntry.getValue() instanceof String) {
return (String) contextEntry.getValue();
} else {
throw new IllegalArgumentException("Personalize context value is not of type String. " +
"Invalid context value: " + contextEntry.getValue());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
import org.opensearch.env.TestEnvironment;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;

public class PersonalizeResponseProcessorTests extends OpenSearchTestCase {
Expand All @@ -42,7 +42,7 @@ public class PersonalizeResponseProcessorTests extends OpenSearchTestCase {
private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign";
private String iamRoleArn = "";
private String recipe = "sample-personalize-recipe";
private String itemIdField = "ITEM_ID";
private String itemIdField = "";
private String region = "us-west-2";
private double weight = 0.25;

Expand Down Expand Up @@ -131,4 +131,82 @@ public void testProcessorWithHits() throws Exception {

personalizeResponseProcessor.processResponse(searchRequest, searchResponse);
}

public void testProcessorWithHitsAndSearchProcessorExt() throws Exception {
PersonalizeClient mockClient = mock(PersonalizeClient.class);

PersonalizeRankingResponseProcessor.Factory factory
= new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient);

Map<String, Object> configuration = new HashMap<>();
configuration.put("campaign_arn", personalizeCampaign);
configuration.put("item_id_field", itemIdField);
configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME);
configuration.put("weight", String.valueOf(weight));
configuration.put("iam_role_arn", iamRoleArn);
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);

Map<String, Object> personalizeContext = new HashMap<>();
personalizeContext.put("contextKey2", "contextValue2");
PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext);
PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder();
extBuilder.setRequestParameters(personalizeRequestParams);

SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource()
.ext(List.of(extBuilder));

SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
SearchHit[] searchHits = new SearchHit[10];
for (int i = 0; i < searchHits.length; i++) {
searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap());
searchHits[i].score(1.0f);
}
SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f);
SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null);

personalizeResponseProcessor.processResponse(searchRequest, searchResponse);
}

public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exception {
PersonalizeClient mockClient = mock(PersonalizeClient.class);

PersonalizeRankingResponseProcessor.Factory factory
= new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient);

Map<String, Object> configuration = new HashMap<>();
configuration.put("campaign_arn", personalizeCampaign);
configuration.put("item_id_field", itemIdField);
configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME);
configuration.put("weight", String.valueOf(weight));
configuration.put("iam_role_arn", iamRoleArn);
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);

Map<String, Object> personalizeContext = new HashMap<>();
personalizeContext.put("contextKey2", 5);
PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext);
PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder();
extBuilder.setRequestParameters(personalizeRequestParams);

SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource()
.ext(List.of(extBuilder));

SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
SearchHit[] searchHits = new SearchHit[10];
for (int i = 0; i < searchHits.length; i++) {
searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap());
searchHits[i].score(1.0f);
}
SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f);
SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null);

personalizeResponseProcessor.processResponse(searchRequest, searchResponse);
}
}
Loading

0 comments on commit 7285598

Please sign in to comment.