Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpelt committed Jul 17, 2019
1 parent 05260c9 commit b4c78ff
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .appveyor.yml
Expand Up @@ -33,5 +33,5 @@ install:
# Not a .NET project, we build in the install step instead
build: false
test_script:
- cd examples
- python -c "import msdnet"
- cd tests
- python traintest.py -v
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -7,6 +7,10 @@ examples/*.txt
examples/*.checkpoint
examples/*/*.tiff
examples/*/*/*.tiff
tests/*.h5
tests/*.png
tests/*.txt
tests/*.checkpoint
.vscode/


Expand Down
4 changes: 2 additions & 2 deletions .travis.yml
Expand Up @@ -13,5 +13,5 @@ install:

script:
- python setup.py install
- cd examples
- python -c "import msdnet"
- cd tests
- python traintest.py -v
63 changes: 63 additions & 0 deletions tests/traintest.py
@@ -0,0 +1,63 @@
import unittest
import msdnet
import numpy as np

class TestTraining(unittest.TestCase):

def test_regr_training(self):
dil = msdnet.dilations.IncrementDilations(3)
n = msdnet.network.MSDNet(10, dil, 1, 1, gpu=False)
n.initialize()
tgt_im = np.zeros((1,128,128), dtype=np.float32)
tgt_im[:,32:96,32:96] = 1
inp_im = (tgt_im + np.random.normal(size=tgt_im.shape)).astype(np.float32)
d = msdnet.data.ArrayDataPoint(inp_im, tgt_im)
bprov = msdnet.data.BatchProvider([d,],1)
val = msdnet.validate.MSEValidation([d,])
t = msdnet.train.AdamAlgorithm(n)
n.normalizeinout([d,])
log1 = msdnet.loggers.ConsoleLogger()
log2 = msdnet.loggers.FileLogger('test_log_regr.txt')
log3 = msdnet.loggers.ImageLogger('test_log_regr')
for i in range(10):
t.step(n, bprov.getbatch())
t.to_file('test_regr_params.h5')
val.to_file('test_regr_params.h5')
n.to_file('test_regr_params.h5')
val.validate(n)
log1.log(val)
log2.log(val)
log3.log(val)
n_load = msdnet.network.MSDNet.from_file('test_regr_params.h5')
n_load.forward(inp_im)

def test_segm_training(self):
dil = msdnet.dilations.IncrementDilations(3)
n = msdnet.network.SegmentationMSDNet(10, dil, 1, 2, gpu=False)
n.initialize()
tgt_im = np.zeros((1,128,128), dtype=np.uint8)
tgt_im[:,32:96,32:96] = 1
inp_im = (tgt_im + np.random.normal(size=tgt_im.shape)).astype(np.float32)
d = msdnet.data.ArrayDataPoint(inp_im, tgt_im)
d = msdnet.data.OneHotDataPoint(d,[0,1])
bprov = msdnet.data.BatchProvider([d,],1)
val = msdnet.validate.MSEValidation([d,])
t = msdnet.train.AdamAlgorithm(n)
n.normalizeinout([d,])
log1 = msdnet.loggers.ConsoleLogger()
log2 = msdnet.loggers.FileLogger('test_log_segm.txt')
log3 = msdnet.loggers.ImageLabelLogger('test_log_segm')
for i in range(10):
t.step(n, bprov.getbatch())
t.to_file('test_segm_params.h5')
val.to_file('test_segm_params.h5')
n.to_file('test_segm_params.h5')
val.validate(n)
log1.log(val)
log2.log(val)
log3.log(val)
n_load = msdnet.network.MSDNet.from_file('test_segm_params.h5')
n_load.forward(inp_im)

if __name__ == '__main__':
unittest.main()

0 comments on commit b4c78ff

Please sign in to comment.