Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add test/test_call_lua.py. Also add requirements.txt, test/requiremen…
…ts.txt, pytest.ini
- Loading branch information
1 parent
60d3453
commit 1c771ea
Showing
8 changed files
with
88 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,3 +65,5 @@ docs/_build/ | |
|
||
# PyBuilder | ||
target/ | ||
|
||
test/junit-pytest-report.xml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[pytest] | ||
addopts = --junitxml=test/junit-pytest-report.xml -ra | ||
testpaths = test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
numpy | ||
docopt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pytest | ||
python-mnist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
require 'torch' | ||
require 'nn' | ||
|
||
local TestCallLua = torch.class('TestCallLua') | ||
|
||
function TestCallLua:__init(someName) | ||
print('TestCallLua:__init(', someName, ')') | ||
self.someName = someName | ||
end | ||
|
||
function TestCallLua:getName() | ||
return self.someName | ||
end | ||
|
||
function TestCallLua:getOut(inTensor, outSize, kernelSize) | ||
local inSize = inTensor:size(3) | ||
local m = nn.TemporalConvolution(inSize, outSize, kernelSize) | ||
m:float() | ||
local out = m:forward(inTensor) | ||
print('out from lua', out) | ||
return out | ||
end | ||
|
||
function TestCallLua:addOne(inTensor) | ||
local outTensor = inTensor + 3 | ||
return outTensor | ||
end | ||
|
||
function TestCallLua:printTable(sometable, somestring, table2) | ||
for k, v in pairs(sometable) do | ||
print('TestCallLua:printTable ', k, v) | ||
end | ||
print('somestring', somestring) | ||
for k, v in pairs(table2) do | ||
print('TestCallLua table2 ', k, v) | ||
end | ||
return {bear='happy', result=12.345, foo='bar'} | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import PyTorch | ||
import PyTorchHelpers | ||
import numpy as np | ||
|
||
|
||
def test_call_lua(): | ||
TestCallLua = PyTorchHelpers.load_lua_class('test/test_call_lua.lua', 'TestCallLua') | ||
|
||
batchSize = 2 | ||
numFrames = 4 | ||
inSize = 3 | ||
outSize = 3 | ||
kernelSize = 3 | ||
|
||
luabit = TestCallLua('green') | ||
print(luabit.getName()) | ||
assert luabit.getName() == 'green' | ||
print('type(luabit)', type(luabit)) | ||
assert str(type(luabit)) == '<class \'PyTorchLua.TestCallLua\'>' | ||
|
||
np.random.seed(123) | ||
inTensor = np.random.randn(batchSize, numFrames, inSize).astype('float32') | ||
luain = PyTorch.asFloatTensor(inTensor) | ||
|
||
luaout = luabit.getOut(luain, outSize, kernelSize) | ||
|
||
outTensor = luaout.asNumpyTensor() | ||
print('outTensor', outTensor) | ||
# I guess we just assume if we got to this point, with no exceptions, then thats a good thing... | ||
# lets add some new test... | ||
|
||
outTensor = luabit.addOne(luain).asNumpyTensor() | ||
assert isinstance(outTensor, np.ndarray) | ||
assert inTensor.shape == outTensor.shape | ||
assert np.abs((inTensor + 3) - outTensor).max() < 1e-4 | ||
|
||
res = luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234}, 'mistletoe', { | ||
'row1': 'col1', 'meta': 'data'}) | ||
print('res', res) | ||
assert res == {'foo': 'bar', 'result': 12.345, 'bear': 'happy'} |