Skip to content

Commit

Permalink
fix panic if zero IDF (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Nov 9, 2022
1 parent 36ecca3 commit e2ab769
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 4 deletions.
24 changes: 20 additions & 4 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error {
t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback))
// inverse document frequency of users
for i := range dataset.UserFeedback {
userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i])))
if dataset.ItemCount() == len(dataset.UserFeedback[i]) {
userIDF[i] = 1
} else {
userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i])))
}
}
t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback))
}
Expand All @@ -303,7 +307,11 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error {
t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels))
// inverse document frequency of labels
for i := range labeledItems {
labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i])))
if dataset.ItemCount() == len(labeledItems[i]) {
labelIDF[i] = 1
} else {
labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i])))
}
}
t.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems))
}
Expand Down Expand Up @@ -597,7 +605,11 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error {
t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback))
// inverse document frequency of items
for i := range dataset.ItemFeedback {
itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i])))
if dataset.UserCount() == len(dataset.ItemFeedback[i]) {
itemIDF[i] = 1
} else {
itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i])))
}
}
t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback))
}
Expand All @@ -614,7 +626,11 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error {
t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels))
// inverse document frequency of labels
for i := range labeledUsers {
labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i])))
if dataset.UserCount() == len(labeledUsers[i]) {
labelIDF[i] = 1
} else {
labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i])))
}
}
t.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers))
}
Expand Down
88 changes: 88 additions & 0 deletions master/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,50 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) {
assert.Equal(t, task.StatusComplete, m.taskMonitor.Tasks[TaskFindItemNeighbors].Status)
}

// TestMaster_FindItemNeighborsIVF_ZeroIDF runs the item-neighbors task with
// the index enabled on a two-item dataset in which both items are inserted
// with identical attributes and both receive feedback from the same single
// user. The task must finish without error for both neighbor types and store
// item "1" as the only neighbor of item "0".
func TestMaster_FindItemNeighborsIVF_ZeroIDF(t *testing.T) {
	// mock master with index-based item neighbor search switched on
	m := newMockMaster(t)
	defer m.Close()
	m.Config = &config.Config{}
	m.Config.Recommend.CacheSize = 3
	m.Config.Master.NumJobs = 4
	m.Config.Recommend.ItemNeighbors.EnableIndex = true
	m.Config.Recommend.ItemNeighbors.IndexRecall = 1
	m.Config.Recommend.ItemNeighbors.IndexFitEpoch = 10

	// two items that differ only by ID (the shared []string{"a"} is
	// presumably the label list - confirm against data.Item)
	assert.NoError(t, m.DataClient.BatchInsertItems([]data.Item{
		{"0", false, []string{"*"}, time.Now(), []string{"a"}, ""},
		{"1", false, []string{"*"}, time.Now(), []string{"a"}, ""},
	}))
	// user "0" gives feedback to both items
	assert.NoError(t, m.DataClient.BatchInsertFeedback([]data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "0"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "1"}},
	}, true, true, true))
	trainSet, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
	assert.NoError(t, err)
	m.rankingTrainSet = trainSet

	// expectOnlyNeighbor runs a fresh task under the currently configured
	// neighbor type and checks the neighbors cached for item "0".
	expectOnlyNeighbor := func() {
		job := NewFindItemNeighborsTask(&m.Master)
		assert.NoError(t, job.run(nil))
		neighbors, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "0"), 0, 100)
		assert.NoError(t, err)
		assert.Equal(t, []string{"1"}, cache.RemoveScores(neighbors))
	}

	// similar items (common users)
	m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated
	expectOnlyNeighbor()

	// similar items (common labels)
	m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar
	expectOnlyNeighbor()
}

func TestMaster_FindUserNeighborsBruteForce(t *testing.T) {
// create mock master
m := newMockMaster(t)
Expand Down Expand Up @@ -421,6 +465,50 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) {
assert.Equal(t, task.StatusComplete, m.taskMonitor.Tasks[TaskFindUserNeighbors].Status)
}

// TestMaster_FindUserNeighborsIVF_ZeroIDF runs the user-neighbors task with
// the index enabled on a two-user dataset in which both users are inserted
// with identical attributes and both give feedback to the same single item.
// The task must finish without error for both neighbor types and store user
// "1" as the only neighbor of user "0".
func TestMaster_FindUserNeighborsIVF_ZeroIDF(t *testing.T) {
	// mock master with index-based user neighbor search switched on
	m := newMockMaster(t)
	defer m.Close()
	m.Config = &config.Config{}
	m.Config.Recommend.CacheSize = 3
	m.Config.Master.NumJobs = 4
	m.Config.Recommend.UserNeighbors.EnableIndex = true
	m.Config.Recommend.UserNeighbors.IndexRecall = 1
	m.Config.Recommend.UserNeighbors.IndexFitEpoch = 10

	// two users that differ only by ID (the shared []string{"a"} is
	// presumably the label list - confirm against data.User)
	assert.NoError(t, m.DataClient.BatchInsertUsers([]data.User{
		{"0", []string{"a"}, nil, ""},
		{"1", []string{"a"}, nil, ""},
	}))
	// both users give feedback to item "0"
	assert.NoError(t, m.DataClient.BatchInsertFeedback([]data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "0"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "1", ItemId: "0"}},
	}, true, true, true))
	trainSet, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
	assert.NoError(t, err)
	m.rankingTrainSet = trainSet

	// expectOnlyNeighbor runs a fresh task under the currently configured
	// neighbor type and checks the neighbors cached for user "0".
	expectOnlyNeighbor := func() {
		job := NewFindUserNeighborsTask(&m.Master)
		assert.NoError(t, job.run(nil))
		neighbors, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "0"), 0, 100)
		assert.NoError(t, err)
		assert.Equal(t, []string{"1"}, cache.RemoveScores(neighbors))
	}

	// similar users (common items)
	m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated
	expectOnlyNeighbor()

	// similar users (common labels)
	m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar
	expectOnlyNeighbor()
}

func TestMaster_LoadDataFromDatabase(t *testing.T) {
// create mock master
m := newMockMaster(t)
Expand Down

0 comments on commit e2ab769

Please sign in to comment.