Skip to content
This repository has been archived by the owner on Oct 30, 2019. It is now read-only.

Commit

Permalink
Limit top-5 to top-N for N-way classifier
Browse files Browse the repository at this point in the history
Don't error when trying to compute top-5 for binary classifier

Fixes #14
  • Loading branch information
colesbury committed Feb 25, 2016
1 parent be6be02 commit eff127f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions train.lua
Expand Up @@ -138,8 +138,12 @@ function Trainer:computeScore(output, target, nCrops)
local correct = predictions:eq(
target:long():view(batchSize, 1):expandAs(output))

local top1 = 1.0 - correct:narrow(2, 1, 1):sum() / batchSize
local top5 = 1.0 - correct:narrow(2, 1, 5):sum() / batchSize
-- Top-1 score
local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)

-- Top-5 score, if there are at least 5 classes
local len = math.min(5, correct:size(2))
local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize)

return top1 * 100, top5 * 100
end
Expand Down

0 comments on commit eff127f

Please sign in to comment.