Skip to content
This repository has been archived by the owner on Aug 28, 2021. It is now read-only.

Commit

Permalink
Fix bug in training
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuandong Tian committed Oct 18, 2016
1 parent e6f58a4 commit 2ed0c08
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion train/rl_framework/infra/bundle.lua
Expand Up @@ -182,7 +182,12 @@ local function get_top5(output, v)
local batchsize = output:size(1)
local top_accuracy = torch.FloatTensor(topn):zero()
local _, sorted_indices = output:sort(2, true)
-- require 'fb.debugger'.enter()
if torch.typename(sorted_indices) ~= 'torch.CudaTensor' then
sorted_indices2 = torch.CudaTensor(unpack(sorted_indices:size():totable()))
sorted_indices2:copy(sorted_indices)
sorted_indices = sorted_indices2
end

for i = 1, topn do
local accuracy = v:eq(sorted_indices:narrow(2, i, 1)):sum()
top_accuracy[i] = accuracy
Expand Down

0 comments on commit 2ed0c08

Please sign in to comment.