Skip to content
This repository has been archived by the owner on Jul 27, 2022. It is now read-only.

Commit

Permalink
Merge pull request #30 from deepmind/fix_erf
Browse files Browse the repository at this point in the history
Fix problem in calling row vectors
  • Loading branch information
d11 committed May 5, 2015
2 parents 6ddf31f + 21d1104 commit 459ddec
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
6 changes: 4 additions & 2 deletions luasrc/error_handling.lua
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ end
Process the optional return storage, the sizes of the parameter functions, etc
@param K number of actual parameters for the sampler
@param K number of actual parameters required by the sampler
@param defaultResultType Tensor class corresponding to the expected result type (e.g. torch.DoubleTensor, torch.IntegerTensor, etc)
@param ... List of all parameters passed to the original caller
Expand Down Expand Up @@ -339,8 +339,10 @@ function cephes._check1DParams(K, defaultResultType, ...)
result:resize(Nresult)
end
-- Expand parameters which are of the wrong size. Note: they have
-- to be single-element to be expanded
for paramIndex, param in ipairs(params) do
if param:size(1) == 1 then
if param:nElement() == 1 then
local sizes = param:size()
sizes[1] = Nparams
params[paramIndex] = params[paramIndex]:expand(sizes)
Expand Down
45 changes: 45 additions & 0 deletions luasrc/tests/test_vectorized.lua
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,50 @@ function vectorizeTests.testNdtr()
tester:asserteq(result:size(1), n, "should get " .. n .. " results")
end


function vectorizeTests.testSingleRowInput()
local x = torch.Tensor({{1, 2}})
local a = cephes.erf(x)
tester:assertTensorEq(a,
torch.Tensor{cephes.erf(1),
cephes.erf(2)},
1e-16,
'Wrong output')
end


function vectorizeTests.testSingleRowInputWithResult()
local x = torch.Tensor({{1, 2}})
local a = torch.Tensor(2)
cephes.erf(a, x)
tester:assertTensorEq(a,
torch.Tensor{cephes.erf(1),
cephes.erf(2)},
1e-16,
'Wrong output')
end


function vectorizeTests.testSingleColumnInput()
local x = torch.Tensor{{1}, {2}}
local a = cephes.erf(x)
tester:assertTensorEq(a,
torch.Tensor{cephes.erf(1),
cephes.erf(2)},
1e-16,
'Wrong output')
end


function vectorizeTests.testSingleColumnInputWithResult()
local x = torch.Tensor{{1}, {2}}
local a = torch.Tensor(2)
cephes.erf(a, x)
tester:assertTensorEq(a,
torch.Tensor{cephes.erf(1),
cephes.erf(2)},
1e-16,
'Wrong output')
end
tester:add(vectorizeTests)
return tester:run()

0 comments on commit 459ddec

Please sign in to comment.