Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): support multi-categories filtering #704

Merged
merged 1 commit into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions master/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
// parse arguments
recommender := request.PathParameter("recommender")
userId := request.PathParameter("user-id")
category := request.PathParameter("category")
categories := []string{request.PathParameter("category")}
n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
if err != nil {
server.BadRequest(response, err)
Expand All @@ -765,13 +765,13 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
var results []string
switch recommender {
case "offline":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendOffline)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendOffline)
case "collaborative":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendCollaborative)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendCollaborative)
case "user_based":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendUserBased)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendUserBased)
case "item_based":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendItemBased)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendItemBased)
case "_":
recommenders := []server.Recommender{m.RecommendOffline}
for _, recommender := range m.Config.Recommend.Online.FallbackRecommend {
Expand All @@ -791,7 +791,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
return
}
}
results, err = m.Recommend(ctx, response, userId, category, n, recommenders...)
results, err = m.Recommend(ctx, response, userId, categories, n, recommenders...)
}
if err != nil {
server.InternalServerError(response, err)
Expand Down
30 changes: 17 additions & 13 deletions server/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ func (s *RestServer) CreateWebService() {
Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}).
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")).
Param(ws.PathParameter("user-id", "ID of the user to get recommendation").DataType("string")).
Param(ws.QueryParameter("category", "Category of the returned items (support multi-categories filtering)").DataType("string")).
Param(ws.QueryParameter("write-back-type", "Type of write back feedback").DataType("string")).
Param(ws.QueryParameter("write-back-delay", "Timestamp delay of write back feedback (format 0h0m0s)").DataType("string")).
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")).
Expand Down Expand Up @@ -684,11 +685,11 @@ func (s *RestServer) getCollaborative(request *restful.Request, response *restfu
// 1. If there are recommendations in cache, return cached recommendations.
// 2. If there are historical interactions of the users, return similar items.
// 3. Otherwise, return fallback recommendation (popular/latest).
func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId, category string, n int, recommenders ...Recommender) ([]string, error) {
func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId string, categories []string, n int, recommenders ...Recommender) ([]string, error) {
initStart := time.Now()

// create context
recommendCtx, err := s.createRecommendContext(ctx, userId, category, n)
recommendCtx, err := s.createRecommendContext(ctx, userId, categories, n)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -727,7 +728,7 @@ func (s *RestServer) Recommend(ctx context.Context, response *restful.Response,
type recommendContext struct {
context context.Context
userId string
category string
categories []string
userFeedback []data.Feedback
n int
results []string
Expand All @@ -750,7 +751,7 @@ type recommendContext struct {
loadPopularTime time.Duration
}

func (s *RestServer) createRecommendContext(ctx context.Context, userId, category string, n int) (*recommendContext, error) {
func (s *RestServer) createRecommendContext(ctx context.Context, userId string, categories []string, n int) (*recommendContext, error) {
// pull historical feedback
userFeedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now())
if err != nil {
Expand All @@ -764,7 +765,7 @@ func (s *RestServer) createRecommendContext(ctx context.Context, userId, categor
}
return &recommendContext{
userId: userId,
category: category,
categories: categories,
n: n,
excludeSet: excludeSet,
userFeedback: userFeedback,
Expand All @@ -777,7 +778,7 @@ type Recommender func(ctx *recommendContext) error
func (s *RestServer) RecommendOffline(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -797,7 +798,7 @@ func (s *RestServer) RecommendOffline(ctx *recommendContext) error {
func (s *RestServer) RecommendCollaborative(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -836,7 +837,7 @@ func (s *RestServer) RecommendUserBased(ctx *recommendContext) error {
if err != nil {
return errors.Trace(err)
}
if ctx.category == "" || funk.ContainsString(item.Categories, ctx.category) {
if funk.Equal(ctx.categories, []string{""}) || funk.Subset(ctx.categories, item.Categories) {
candidates[feedback.ItemId] += user.Score
}
}
Expand Down Expand Up @@ -876,7 +877,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error {
candidates := make(map[string]float64)
for _, feedback := range userFeedback {
// load similar items
similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -906,7 +907,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error {
func (s *RestServer) RecommendLatest(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -926,7 +927,7 @@ func (s *RestServer) RecommendLatest(ctx *recommendContext) error {
func (s *RestServer) RecommendPopular(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -955,7 +956,10 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re
BadRequest(response, err)
return
}
category := request.PathParameter("category")
categories := request.QueryParameters("category")
if len(categories) == 0 {
categories = []string{request.PathParameter("category")}
}
offset, err := ParseInt(request, "offset", 0)
if err != nil {
BadRequest(response, err)
Expand Down Expand Up @@ -986,7 +990,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re
return
}
}
results, err := s.Recommend(ctx, response, userId, category, offset+n, recommenders...)
results, err := s.Recommend(ctx, response, userId, categories, offset+n, recommenders...)
if err != nil {
InternalServerError(response, err)
return
Expand Down
30 changes: 30 additions & 0 deletions server/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,36 @@ func (suite *ServerTestSuite) TestGetRecommends() {
End()
}

func (suite *ServerTestSuite) TestGetRecommendsWithMultiCategories() {
ctx := context.Background()
t := suite.T()
// insert recommendation
err := suite.CacheClient.AddDocuments(ctx, cache.OfflineRecommend, "0", []cache.Document{
{Id: "1", Score: 1, Categories: []string{""}},
{Id: "2", Score: 2, Categories: []string{"", "2"}},
{Id: "3", Score: 3, Categories: []string{"", "3"}},
{Id: "4", Score: 4, Categories: []string{"", "2"}},
{Id: "5", Score: 5, Categories: []string{"", "5"}},
{Id: "6", Score: 6, Categories: []string{"", "2", "3"}},
{Id: "7", Score: 7, Categories: []string{"", "7"}},
{Id: "8", Score: 8, Categories: []string{"", "2"}},
{Id: "9", Score: 9, Categories: []string{"", "3"}},
})
suite.NoError(err)
apitest.New().
Handler(suite.handler).
Get("/api/recommend/0").
Header("X-API-Key", apiKey).
QueryCollection(map[string][]string{
"n": []string{"3"},
"category": []string{"2", "3"},
}).
Expect(t).
Status(http.StatusOK).
Body(suite.marshal([]string{"6"})).
End()
}

func (suite *ServerTestSuite) TestGetRecommendsWithReplacement() {
ctx := context.Background()
t := suite.T()
Expand Down