Skip to content
Permalink
Browse files

tests(base): add unit test for load a dumped pipeline from yaml

  • Loading branch information...
hanxiao committed Jul 22, 2019
1 parent af7b2f8 commit 499682ce942c5fac778d8c09f40f95606439114d
Showing with 24 additions and 2 deletions.
  1. +24 −2 tests/test_load_dump_pipeline.py
@@ -1,7 +1,17 @@
import os
import unittest

from gnes.encoder.base import BaseEncoder
from gnes.encoder.base import BaseEncoder, PipelineEncoder


class DummyTFEncoder(BaseEncoder):
def post_init(self):
import tensorflow as tf
self.a = tf.get_variable(name='a', shape=[])
self.sess = tf.Session()

def encode(self, a, *args):
return self.sess.run(self.a + 1, feed_dict={self.a: a})


class TestLoadDumpPipeline(unittest.TestCase):
@@ -22,5 +32,17 @@ def test_base(self):
b = BaseEncoder.load_yaml(self.yaml_path)
self.assertTrue(b.is_trained)

def test_dummytf(self):
d1 = DummyTFEncoder()
self.assertEqual(d1.encode(1), 2)

d2 = DummyTFEncoder()
self.assertEqual(d2.encode(2), 3)

d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
self.assertEqual(d2.encode(1), 3)

def tearDown(self):
os.remove(self.dump_path)
if os.path.exists(self.dump_path):
os.remove(self.dump_path)

0 comments on commit 499682c

Please sign in to comment.
You can’t perform that action at this time.