Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions server/pkg/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"sort"

"github.com/codingpot/pr12er/server/internal/err"
"github.com/codingpot/pr12er/server/pkg/handlers/prutils"
"github.com/codingpot/pr12er/server/pkg/pr12er"
)

Expand All @@ -19,7 +20,7 @@ func VideosResponseFromDB(db *pr12er.Database) *pr12er.GetVideosResponse {
Title: dataVideo.GetVideoTitle(),
Link: getYouTubeLinkFromID(dataVideo.GetVideoId()),
Presenter: dataVideo.GetUploader(),
Category: getCategory(data),
Category: prutils.CategoryFromVideo(data),
NumberOfLike: dataVideo.GetNumberOfLikes(),
Keywords: getKeywords(data),
NumberOfViews: dataVideo.GetNumberOfViews(),
Expand Down Expand Up @@ -48,11 +49,6 @@ func getKeywords(prVideo *pr12er.PrVideo) []string {
return ret
}

// TODO: Implement getCategory based on papers.
func getCategory(prVideo *pr12er.PrVideo) pr12er.Category {
return pr12er.Category_CATEGORY_UNSPECIFIED
}

// getYouTubeLinkFromID returns the full URL.
func getYouTubeLinkFromID(videoID string) string {
return "https://youtu.be/" + videoID
Expand Down
67 changes: 67 additions & 0 deletions server/pkg/handlers/prutils/prutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Package prutils contains util functions for PR model.
package prutils

import (
"strings"

"github.com/codingpot/pr12er/server/pkg/pr12er"
log "github.com/sirupsen/logrus"
)

// NOTE.
// Each keywords should be all lower cased.

var visionKeywords = []string{
"vision",
"detect",
}

var nlpKeywords = []string{
"text",
"sentence",
}

var ocrKeywords = []string{
"ocr",
}

var audioKeywords = []string{
"audio",
}

var recommendationSystemKeywords = []string{
"recommend",
}

// CategoryFromVideo TODO: Convert to more sophisticated algorithm.
func CategoryFromVideo(prVideo *pr12er.PrVideo) pr12er.Category {
title := strings.ToLower(prVideo.GetVideo().GetVideoTitle())

//nolint:gocritic
if containsAnyElem(title, visionKeywords) {
return pr12er.Category_CATEGORY_VISION
} else if containsAnyElem(title, nlpKeywords) {
return pr12er.Category_CATEGORY_NLP
} else if containsAnyElem(title, ocrKeywords) {
return pr12er.Category_CATEGORY_OCR
} else if containsAnyElem(title, audioKeywords) {
return pr12er.Category_CATEGORY_AUDIO
} else if containsAnyElem(title, recommendationSystemKeywords) {
return pr12er.Category_CATEGORY_RS
}

return pr12er.Category_CATEGORY_UNSPECIFIED
}

func containsAnyElem(title string, keywords []string) bool {
for _, keyword := range keywords {
if strings.Contains(title, keyword) {
log.WithFields(log.Fields{
"title": title,
"keyword": keyword,
}).Info("found category keyword")
return true
}
}
return false
}
56 changes: 56 additions & 0 deletions server/pkg/handlers/prutils/prutils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package prutils

import (
"testing"

"github.com/codingpot/pr12er/server/pkg/pr12er"
"github.com/stretchr/testify/assert"
)

func TestCategoryFromVideo(t *testing.T) {
type args struct {
prVideo *pr12er.PrVideo
}
tests := []struct {
name string
args args
want pr12er.Category
}{
{
name: "Recommender Systems should return RS category",
args: args{
prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
VideoTitle: "PR-064: Wide & Deep Learning for Recommender Systems",
}},
},
want: pr12er.Category_CATEGORY_RS,
},
{
name: "Audio paper should return Audio category",
args: args{
prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
VideoTitle: "PR-067: Audio Super Resolution using Neural Nets",
}},
},
want: pr12er.Category_CATEGORY_AUDIO,
},
{
name: "Vision paper returns Vision category",
args: args{
prVideo: &pr12er.PrVideo{
Video: &pr12er.YouTubeVideo{
VideoTitle: "PR-084 MegDet: A Large Mini-Batch Object Detector",
},
},
},
want: pr12er.Category_CATEGORY_VISION,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CategoryFromVideo(tt.args.prVideo)
assert.Equalf(t, tt.want, got, "want %s, got %s", tt.want, got)
})
}
}