From 8c9cf4816b7b9d5a03b3998fb2119f69598514dd Mon Sep 17 00:00:00 2001 From: Mo Kweon Date: Thu, 24 Jun 2021 21:02:42 -0700 Subject: [PATCH] feat: implement CategoryFromVideo --- server/pkg/handlers/handlers.go | 8 +-- server/pkg/handlers/prutils/prutils.go | 67 +++++++++++++++++++++ server/pkg/handlers/prutils/prutils_test.go | 56 +++++++++++++++++ 3 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 server/pkg/handlers/prutils/prutils.go create mode 100644 server/pkg/handlers/prutils/prutils_test.go diff --git a/server/pkg/handlers/handlers.go b/server/pkg/handlers/handlers.go index 046750e9..aed12456 100644 --- a/server/pkg/handlers/handlers.go +++ b/server/pkg/handlers/handlers.go @@ -4,6 +4,7 @@ import ( "sort" "github.com/codingpot/pr12er/server/internal/err" + "github.com/codingpot/pr12er/server/pkg/handlers/prutils" "github.com/codingpot/pr12er/server/pkg/pr12er" ) @@ -19,7 +20,7 @@ func VideosResponseFromDB(db *pr12er.Database) *pr12er.GetVideosResponse { Title: dataVideo.GetVideoTitle(), Link: getYouTubeLinkFromID(dataVideo.GetVideoId()), Presenter: dataVideo.GetUploader(), - Category: getCategory(data), + Category: prutils.CategoryFromVideo(data), NumberOfLike: dataVideo.GetNumberOfLikes(), Keywords: getKeywords(data), NumberOfViews: dataVideo.GetNumberOfViews(), @@ -48,11 +49,6 @@ func getKeywords(prVideo *pr12er.PrVideo) []string { return ret } -// TODO: Implement getCategory based on papers. -func getCategory(prVideo *pr12er.PrVideo) pr12er.Category { - return pr12er.Category_CATEGORY_UNSPECIFIED -} - // getYouTubeLinkFromID returns the full URL. func getYouTubeLinkFromID(videoID string) string { return "https://youtu.be/" + videoID diff --git a/server/pkg/handlers/prutils/prutils.go b/server/pkg/handlers/prutils/prutils.go new file mode 100644 index 00000000..9058bcf8 --- /dev/null +++ b/server/pkg/handlers/prutils/prutils.go @@ -0,0 +1,67 @@ +// Package prutils contains util functions for PR model. 
+package prutils
+
+import (
+	"strings"
+
+	"github.com/codingpot/pr12er/server/pkg/pr12er"
+	log "github.com/sirupsen/logrus"
+)
+
+// Keyword tables used by CategoryFromVideo. Every keyword must be lower-cased
+// because it is matched against a lower-cased video title.
+var (
+	visionKeywords               = []string{"vision", "detect"}
+	nlpKeywords                  = []string{"text", "sentence"}
+	ocrKeywords                  = []string{"ocr"}
+	audioKeywords                = []string{"audio"}
+	recommendationSystemKeywords = []string{"recommend"}
+)
+
+// CategoryFromVideo guesses the category of prVideo from its video title.
+//
+// The title is lower-cased and tested against each keyword table in a fixed
+// order: vision, NLP, OCR, audio, then recommender systems. The first table
+// with a keyword occurring anywhere in the title decides the result, so a
+// title matching several tables gets the earliest category in that order.
+// Category_CATEGORY_UNSPECIFIED is returned when no keyword matches.
+//
+// Matching is by plain substring, so a keyword also matches inside a longer
+// word (for example "detect" matches "detector").
+//
+// TODO: Convert to a more sophisticated algorithm.
+func CategoryFromVideo(prVideo *pr12er.PrVideo) pr12er.Category {
+	title := strings.ToLower(prVideo.GetVideo().GetVideoTitle())
+
+	// Cases are tried top to bottom; the first one that matches wins.
+	switch {
+	case containsAnyElem(title, visionKeywords):
+		return pr12er.Category_CATEGORY_VISION
+	case containsAnyElem(title, nlpKeywords):
+		return pr12er.Category_CATEGORY_NLP
+	case containsAnyElem(title, ocrKeywords):
+		return pr12er.Category_CATEGORY_OCR
+	case containsAnyElem(title, audioKeywords):
+		return pr12er.Category_CATEGORY_AUDIO
+	case containsAnyElem(title, recommendationSystemKeywords):
+		return pr12er.Category_CATEGORY_RS
+	default:
+		return pr12er.Category_CATEGORY_UNSPECIFIED
+	}
+}
+
+// containsAnyElem reports whether title contains at least one of keywords as
+// a substring. The first matching keyword is logged at info level together
+// with the title; later keywords are not examined.
+func containsAnyElem(title string, keywords []string) bool {
+	for _, keyword := range keywords {
+		if strings.Contains(title, keyword) {
+			log.WithFields(log.Fields{
+				"title":   title,
+				"keyword": keyword,
+			}).Info("found category keyword")
+			return true
+		}
+	}
+	return false
+}
diff --git a/server/pkg/handlers/prutils/prutils_test.go b/server/pkg/handlers/prutils/prutils_test.go
new file mode 100644
index 00000000..7a707192
--- /dev/null
+++ b/server/pkg/handlers/prutils/prutils_test.go
@@ -0,0 +1,56 @@
+package prutils
+
+import (
+	"testing"
+
+	"github.com/codingpot/pr12er/server/pkg/pr12er"
+	"github.com/stretchr/testify/assert"
+)
+
+// TestCategoryFromVideo checks that a title containing a category keyword
+// is mapped to that category.
+func TestCategoryFromVideo(t *testing.T) {
+	type args struct {
+		prVideo *pr12er.PrVideo
+	}
+	tests := []struct {
+		name string
+		args args
+		want pr12er.Category
+	}{
+		{
+			name: "Recommender Systems should return RS category",
+			args: args{
+				prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
+					VideoTitle: "PR-064: Wide & Deep Learning for Recommender Systems",
+				}},
+			},
+			want: pr12er.Category_CATEGORY_RS,
+		},
+		{
+			name: "Audio paper should return Audio category",
+			args: args{
+				prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
+					VideoTitle: "PR-067: Audio Super Resolution using Neural Nets",
+				}},
+			},
+			want: pr12er.Category_CATEGORY_AUDIO,
+		},
+		{
+			name: "Vision paper returns Vision category",
+			args: args{
+				prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
+					VideoTitle: "PR-084 MegDet: A Large Mini-Batch Object Detector",
+				}},
+			},
+			want: pr12er.Category_CATEGORY_VISION,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := CategoryFromVideo(tt.args.prVideo)
+			assert.Equalf(t, tt.want, got, "want %s, got %s", tt.want, got)
+		})
+	}
+}