Skip to content

Commit

Permalink
add test/test_call_lua.py. Also add requirements.txt, test/requiremen…
Browse files Browse the repository at this point in the history
…ts.txt, pytest.ini
  • Loading branch information
hughperkins committed Sep 8, 2016
1 parent 60d3453 commit 1c771ea
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -65,3 +65,5 @@ docs/_build/

# PyBuilder
target/

test/junit-pytest-report.xml
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
addopts = --junitxml=test/junit-pytest-report.xml -ra
testpaths = test
2 changes: 2 additions & 0 deletions requirements.txt
@@ -0,0 +1,2 @@
numpy
docopt
1 change: 0 additions & 1 deletion simpleexample/luabit.lua
Expand Up @@ -31,4 +31,3 @@ function Luabit:printTable(sometable, somestring, table2)
end
return {bear='happy', result=12.345, foo='bar'}
end

5 changes: 1 addition & 4 deletions simpleexample/pybit.py
@@ -1,5 +1,3 @@
import sys
import os
import PyTorch
import PyTorchHelpers
import numpy as np
Expand All @@ -25,6 +23,5 @@
print('outTensor', outTensor)

res = luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234}, 'mistletoe', {
'row1': 'col1', 'meta': 'data'})
'row1': 'col1', 'meta': 'data'})
print('res', res)

2 changes: 2 additions & 0 deletions test/requirements.txt
@@ -0,0 +1,2 @@
pytest
python-mnist
38 changes: 38 additions & 0 deletions test/test_call_lua.lua
@@ -0,0 +1,38 @@
require 'torch'
require 'nn'

local TestCallLua = torch.class('TestCallLua')

function TestCallLua:__init(someName)
print('TestCallLua:__init(', someName, ')')
self.someName = someName
end

function TestCallLua:getName()
return self.someName
end

function TestCallLua:getOut(inTensor, outSize, kernelSize)
local inSize = inTensor:size(3)
local m = nn.TemporalConvolution(inSize, outSize, kernelSize)
m:float()
local out = m:forward(inTensor)
print('out from lua', out)
return out
end

function TestCallLua:addOne(inTensor)
local outTensor = inTensor + 3
return outTensor
end

function TestCallLua:printTable(sometable, somestring, table2)
for k, v in pairs(sometable) do
print('TestCallLua:printTable ', k, v)
end
print('somestring', somestring)
for k, v in pairs(table2) do
print('TestCallLua table2 ', k, v)
end
return {bear='happy', result=12.345, foo='bar'}
end
40 changes: 40 additions & 0 deletions test/test_call_lua.py
@@ -0,0 +1,40 @@
import PyTorch
import PyTorchHelpers
import numpy as np


def test_call_lua():
TestCallLua = PyTorchHelpers.load_lua_class('test/test_call_lua.lua', 'TestCallLua')

batchSize = 2
numFrames = 4
inSize = 3
outSize = 3
kernelSize = 3

luabit = TestCallLua('green')
print(luabit.getName())
assert luabit.getName() == 'green'
print('type(luabit)', type(luabit))
assert str(type(luabit)) == '<class \'PyTorchLua.TestCallLua\'>'

np.random.seed(123)
inTensor = np.random.randn(batchSize, numFrames, inSize).astype('float32')
luain = PyTorch.asFloatTensor(inTensor)

luaout = luabit.getOut(luain, outSize, kernelSize)

outTensor = luaout.asNumpyTensor()
print('outTensor', outTensor)
# I guess we just assume if we got to this point, with no exceptions, then thats a good thing...
# lets add some new test...

outTensor = luabit.addOne(luain).asNumpyTensor()
assert isinstance(outTensor, np.ndarray)
assert inTensor.shape == outTensor.shape
assert np.abs((inTensor + 3) - outTensor).max() < 1e-4

res = luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234}, 'mistletoe', {
'row1': 'col1', 'meta': 'data'})
print('res', res)
assert res == {'foo': 'bar', 'result': 12.345, 'bear': 'happy'}

0 comments on commit 1c771ea

Please sign in to comment.