From 7285598658e86518bde5dd72869080eadd20bd1f Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 31 May 2023 11:11:18 -0700 Subject: [PATCH] Support contextual metadata to use when getting personalized reranking (#144) (#145) This change supports using contextual metadata for personalized reranking if provided as a search request parameter Change ensures appropriate exception is thrown when context parameter values are not of type String. Added unit tests to improve test coverage. Change validation and exception handling for personalize context --------- Signed-off-by: Ketan Kulkarni (cherry picked from commit 55a702c3fde043ff06c34c96e3f304cbd3270490) Co-authored-by: kulket <130191298+kulket@users.noreply.github.com> --- .../requestparameter/Constants.java | 15 ---- .../PersonalizeRequestParameters.java | 39 ++++++-- ...ersonalizeRequestParametersExtBuilder.java | 2 +- .../impl/AmazonPersonalizedRankerImpl.java | 24 ++++- .../PersonalizeResponseProcessorTests.java | 88 +++++++++++++++++-- .../AmazonPersonalizeRankerImplTests.java | 81 +++++++++++++++++ .../PersonalizeRequestParameterUtilTests.java | 69 ++++++++++++++- ...alizeRequestParametersExtBuilderTests.java | 7 +- 8 files changed, 291 insertions(+), 34 deletions(-) delete mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java deleted file mode 100644 index 7d959cf..0000000 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; - -public class Constants { - - public static final String PERSONALIZE_REQUEST_PARAMETERS = "personalize_request_parameters"; - public static final String USER_ID_PARAMETER = "user_id"; - -} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java index a2130b0..52cb3c3 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java @@ -17,31 +17,45 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; import java.util.Objects; -import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS; -import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.USER_ID_PARAMETER; - public class PersonalizeRequestParameters implements Writeable, ToXContentObject { + static final String PERSONALIZE_REQUEST_PARAMETERS = "personalize_request_parameters"; + private static final String USER_ID_PARAMETER = "user_id"; + private static final String CONTEXT_PARAMETER = "context"; + private static final ObjectParser PARSER; private static final ParseField USER_ID = new ParseField(USER_ID_PARAMETER); + private static final ParseField CONTEXT = new ParseField(CONTEXT_PARAMETER); static { PARSER = new ObjectParser<>(PERSONALIZE_REQUEST_PARAMETERS, PersonalizeRequestParameters::new); PARSER.declareString(PersonalizeRequestParameters::setUserId, USER_ID); + PARSER.declareObject(PersonalizeRequestParameters::setContext,(XContentParser p, Void c) -> { + try { + return p.map(); + } catch (IOException e) { + throw new IllegalArgumentException("Error parsing Personalize context from request parameters", e); + } + }, CONTEXT); } private String userId; + private Map context; + public PersonalizeRequestParameters() {} - public PersonalizeRequestParameters(String userId) { + public PersonalizeRequestParameters(String userId, Map context) { this.userId = userId; + this.context = context; } public PersonalizeRequestParameters(StreamInput input) throws IOException { this.userId = input.readString(); + this.context = input.readMap(); } public String getUserId() { @@ -52,14 +66,24 @@ public void setUserId(String userId) { this.userId = userId; } + public Map getContext() { + return context; + } + + public void setContext(Map context) { + this.context = context; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(this.userId); + out.writeMap(this.context); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.field(USER_ID.getPreferredName(), this.userId); + builder.field(USER_ID.getPreferredName(), this.userId); + return builder.field(CONTEXT.getPreferredName(), this.context); } public static PersonalizeRequestParameters parse(XContentParser parser) throws IOException { @@ -75,11 +99,12 @@ public boolean equals(Object o) { PersonalizeRequestParameters config = (PersonalizeRequestParameters) o; if (!userId.equals(config.userId)) return false; - return userId.equals(config.userId); + if (context.size() != config.getContext().size()) return false; + return userId.equals(config.userId) && context.equals(config.getContext()); } @Override public int hashCode() { - return Objects.hash(userId); + return Objects.hash(userId, context); } } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java index 0dc4f04..ecd567b 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java @@ -18,7 +18,7 @@ import java.io.IOException; import java.util.Objects; -import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters.PERSONALIZE_REQUEST_PARAMETERS; public class PersonalizeRequestParametersExtBuilder extends SearchExtBuilder { private static final Logger logger = LogManager.getLogger(PersonalizeRequestParametersExtBuilder.class); diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java index bd3a964..ba7ade5 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java @@ -9,12 +9,10 @@ import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; -import com.amazonaws.services.personalizeruntime.model.PredictedItem; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; @@ -22,6 +20,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; /** @@ -53,7 +52,7 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa List originalHits = Arrays.asList(hits.getHits()); String itemIdfield = rankerConfig.getItemIdField(); List documentIdsToRank; - // If item field is not specified in the configruation then use default _id field. + // If item field is not specified in the configuration then use default _id field. if (!itemIdfield.isEmpty()) { documentIdsToRank = originalHits.stream() .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) @@ -66,12 +65,20 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa .collect(Collectors.toList()); } logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray())); - // TODO: Parse context from request parameters String userId = requestParameters.getUserId(); + Map context = requestParameters.getContext() != null ? + requestParameters.getContext().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> isValidPersonalizeContext(e))) + : null; logger.info("User ID from request parameters. User ID: {}", userId); + if (context != null && !context.isEmpty()) { + logger.info("Personalize context provided in the search request"); + } + GetPersonalizedRankingRequest personalizeRequest = new GetPersonalizedRankingRequest() .withCampaignArn(rankerConfig.getPersonalizeCampaign()) .withInputList(documentIdsToRank) + .withContext(context) .withUserId(userId); GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest); @@ -103,4 +110,13 @@ public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requ } return isValidPersonalizeConfig; } + + private String isValidPersonalizeContext(Map.Entry contextEntry) throws IllegalArgumentException { + if (contextEntry.getValue() instanceof String) { + return (String) contextEntry.getValue(); + } else { + throw new IllegalArgumentException("Personalize context value is not of type String. " + + "Invalid context value: " + contextEntry.getValue()); + } + } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java index 1bad35b..eccb2ad 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java @@ -19,19 +19,19 @@ import org.opensearch.env.TestEnvironment; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; -import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder; import org.opensearch.test.OpenSearchTestCase; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.WeakHashMap; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; public class PersonalizeResponseProcessorTests extends OpenSearchTestCase { @@ -42,7 +42,7 @@ public class PersonalizeResponseProcessorTests extends OpenSearchTestCase { private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; private String iamRoleArn = ""; private String recipe = "sample-personalize-recipe"; - private String itemIdField = "ITEM_ID"; + private String itemIdField = ""; private String region = "us-west-2"; private double weight = 0.25; @@ -131,4 +131,82 @@ public void testProcessorWithHits() throws Exception { personalizeResponseProcessor.processResponse(searchRequest, searchResponse); } + + public void testProcessorWithHitsAndSearchProcessorExt() throws Exception { + PersonalizeClient mockClient = mock(PersonalizeClient.class); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration); + + Map personalizeContext = new HashMap<>(); + personalizeContext.put("contextKey2", "contextValue2"); + PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(personalizeRequestParams); + + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + SearchHit[] searchHits = new SearchHit[10]; + for (int i = 0; i < searchHits.length; i++) { + searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); + searchHits[i].score(1.0f); + } + SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + personalizeResponseProcessor.processResponse(searchRequest, searchResponse); + } + + public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exception { + PersonalizeClient mockClient = mock(PersonalizeClient.class); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration); + + Map personalizeContext = new HashMap<>(); + personalizeContext.put("contextKey2", 5); + PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(personalizeRequestParams); + + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + SearchHit[] searchHits = new SearchHit[10]; + for (int i = 0; i < searchHits.length; i++) { + searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); + searchHits[i].score(1.0f); + } + SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + personalizeResponseProcessor.processResponse(searchRequest, searchResponse); + } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java index 61c31ef..b5e5ac9 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java @@ -19,6 +19,8 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import static org.mockito.ArgumentMatchers.any; @@ -59,4 +61,83 @@ public void testReRankWithoutItemIdFieldInConfig() throws IOException { SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } + + public void testReRankWithRequestParameterContext() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + Map context = new HashMap<>(); + context.put("contextKey", "contextValue"); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + requestParameters.setContext(context); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } + + public void testReRankWithInvalidRequestParameterContext() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + Map context = new HashMap<>(); + context.put("contextKey", 2); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + requestParameters.setContext(context); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } + + public void testReRankWithNoUserId() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + Map context = new HashMap<>(); + context.put("contextKey", "contextValue"); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setContext(context); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } + + public void testReRankWithEmptyItemIdField() throws IOException { + String itemIdEmpty = ""; + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdEmpty, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } + + public void testReRankWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, null, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java index c997cb7..ea09d79 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java @@ -13,12 +13,14 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class PersonalizeRequestParameterUtilTests extends OpenSearchTestCase { public void testExtractParameters() { - PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1"); + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1", new HashMap<>()); PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); extBuilder.setRequestParameters(expected); SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() @@ -27,4 +29,69 @@ public void testExtractParameters() { PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); assertEquals(expected, actual); } + + public void testExtractParametersWithContext() { + Map context = new HashMap<>(); + context.put("contextKey", "contextValue"); + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1", context); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + assertEquals(expected, actual); + } + + public void testPersonalizeRequestParametersEquals() { + Map notExpectedContext = new HashMap<>(); + notExpectedContext.put("contextKey", "contextValue"); + PersonalizeRequestParameters notExpected = new PersonalizeRequestParameters("user_1", notExpectedContext); + + Map expectedContext = new HashMap<>(); + expectedContext.put("contextKey2", "contextValue2"); + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1", expectedContext); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + assertNotEquals(notExpected, actual); + } + + public void testPersonalizeRequestParametersContextMapDifferentSize() { + Map notExpectedContext = new HashMap<>(); + notExpectedContext.put("contextKey", "contextValue"); + PersonalizeRequestParameters notExpected = new PersonalizeRequestParameters("user_1", notExpectedContext); + + Map expectedContext = new HashMap<>(); + expectedContext.put("contextKey2", "contextValue2"); + expectedContext.put("contextKey22", "contextValue22"); + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1", expectedContext); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + assertNotEquals(notExpected, actual); + } + + public void testPersonalizeRequestParametersUserIdDiffers() { + Map notExpectedContext = new HashMap<>(); + notExpectedContext.put("contextKey", "contextValue"); + PersonalizeRequestParameters notExpected = new PersonalizeRequestParameters("user_1", notExpectedContext); + + Map expectedContext = new HashMap<>(); + expectedContext.put("contextKey", "contextValue"); + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_2", expectedContext); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + assertNotEquals(notExpected, actual); + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java index e2b287e..46b2f90 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java @@ -15,11 +15,15 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; public class PersonalizeRequestParametersExtBuilderTests extends OpenSearchTestCase { public void testXContentRoundTrip() throws IOException { - PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters("28"); + Map context = new HashMap<>(); + context.put("contextKey", "contextValue"); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters("28", context); PersonalizeRequestParametersExtBuilder personalizeExtBuilder = new PersonalizeRequestParametersExtBuilder(); personalizeExtBuilder.setRequestParameters(requestParameters); XContentType xContentType = randomFrom(XContentType.values()); @@ -36,6 +40,7 @@ public void testXContentRoundTrip() throws IOException { public void testStreamRoundTrip() throws IOException { PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); + requestParameters.setContext(new HashMap<>()); PersonalizeRequestParametersExtBuilder personalizeExtBuilder = new PersonalizeRequestParametersExtBuilder(); personalizeExtBuilder.setRequestParameters(requestParameters); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();