This repository has been archived by the owner on Oct 30, 2019. It is now read-only.

Commit

Fix shareGradInput with precision specification. (#154)
gchanan authored and colesbury committed Feb 9, 2017
1 parent 17a0138 commit ef12212
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/init.lua
@@ -115,16 +115,16 @@ function M.shareGradInput(model, opt)
       if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then
          local key = sharingKey(m)
          if cache[key] == nil then
-            cache[key] = torch[self.opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')]()(1)
+            cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
          end
          m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0)
       end
    end)
    for i, m in ipairs(model:findModules('nn.ConcatTable')) do
       if cache[i % 2] == nil then
-         cache[i % 2] = torch.CudaStorage(1)
+         cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
       end
-      m.gradInput = torch[self.opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](cache[i % 2], 1, 0)
+      m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0)
    end
 end
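
For context, the bracketed expressions in the diff look up the Tensor and Storage constructors that match the configured precision (opt.tensorType) instead of the previously hard-coded torch.CudaStorage. A minimal sketch of that string-to-constructor conversion, assuming a Torch7 environment with cutorch loaded; the local names below are illustrative, not from the repository:

    -- Example value of the -tensorType option (illustrative).
    local opt = { tensorType = 'torch.CudaTensor' }

    -- 'torch.CudaTensor' -> 'CudaTensor'
    local tensorName = opt.tensorType:match('torch.(%a+)')
    -- 'CudaTensor' -> 'CudaStorage' (gsub's extra return value is dropped by the assignment)
    local storageName = tensorName:gsub('Tensor', 'Storage')

    -- Indexing the torch table by name yields the constructors themselves:
    local storage = torch[storageName](1)              -- same as torch.CudaStorage(1)
    local grad    = torch[tensorName](storage, 1, 0)   -- empty tensor viewing that shared storage

This is why the fix replaces self.opt (undefined inside the apply closure) with the opt argument, and builds a Tensor rather than a Storage for m.gradInput in the ConcatTable loop.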

