Skip to content

Commit

Permalink
adjusts tests for new output folder
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedhosny committed Nov 1, 2018
1 parent cd057c5 commit de85c22
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 72 deletions.
49 changes: 30 additions & 19 deletions framework/modelhubapi_tests/apitestbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,35 @@ class TestAPIBase(unittest.TestCase):
"""
Defines common functionality for api test cases.
This mainly defines convenience asserts to avoid duplicating them in
This mainly defines convenience asserts to avoid duplicating them in
the actual api test cases (since the difference in the APIs is mostly
how they are called, but not the actual results we are expecting)
"""

def setup_self_temp_output_dir(self):
self.temp_output_dir = os.path.join(self.this_dir, "temp_output_dir")
if not os.path.exists(self.temp_output_dir):
os.makedirs(self.temp_output_dir)

def assert_config_contains_correct_dict(self, config):
self.assertEqual("MockId", config["id"])
self.assertEqual("MockNet", config["meta"]["name"])


def assert_legal_contains_expected_keys(self, legal):
self.assertNotIn("error", legal)
keys = sorted(legal)
referenceKeys = ["model_license",
"modelhub_acknowledgements",
"modelhub_license",
referenceKeys = ["model_license",
"modelhub_acknowledgements",
"modelhub_license",
"sample_data_license"]
self.assertListEqual(referenceKeys, keys)


def assert_legal_contains_expected_mock_values(self, legal):
self.assertEqual("TEST MODEL LICENSE", legal["model_license"])
self.assertEqual("TEST SAMPLE DATA LICENSE", legal["sample_data_license"])


def assert_model_io_contains_expected_mock_values(self, model_io):
self.assertListEqual(["image/png"], model_io["input"]["format"])
Expand All @@ -46,15 +51,15 @@ def assert_model_io_contains_expected_mock_values(self, model_io):
self.assertEqual("mask_image", model_io["output"][1]["type"])


def assert_predict_contains_expected_mock_prediction(self, result, expectNumpy = False):
def assert_predict_contains_expected_mock_prediction(self, result, expectList = False):
self.assertEqual("class_0", result["output"][0]["prediction"][0]["label"])
self.assertEqual(0.3, result["output"][0]["prediction"][0]["probability"])
self.assertEqual("class_1", result["output"][0]["prediction"][1]["label"])
self.assertEqual(0.7, result["output"][0]["prediction"][1]["probability"])
if expectNumpy:
result["output"][1]["prediction"] = result["output"][1]["prediction"].tolist()
self.assertListEqual([[0,1,1,0],[0,2,2,0]], result["output"][1]["prediction"])

if expectList:
self.assertListEqual([[0,1,1,0],[0,2,2,0]], result["output"][1]["prediction"])
else:
self.assertIsInstance(result["output"][1]["prediction"], basestring)

def assert_predict_contains_expected_mock_meta_info(self, result):
self.assertEqual("label_list", result["output"][0]["type"])
Expand All @@ -73,18 +78,25 @@ class TestRESTAPIBase(TestAPIBase):
Defines common functionality for rest api test cases
"""

def setup_self_temp_workdir(self):
self.temp_workdir = os.path.join(self.this_dir, "temp_workdir")
if not os.path.exists(self.temp_workdir):
os.makedirs(self.temp_workdir)

def setup_self_temp_work_dir(self):
self.temp_work_dir = os.path.join(self.this_dir, "temp_work_dir")
if not os.path.exists(self.temp_work_dir):
os.makedirs(self.temp_work_dir)

def setup_self_temp_output_dir(self):
self.temp_output_dir = os.path.join(self.this_dir, "temp_output_dir")
if not os.path.exists(self.temp_output_dir):
os.makedirs(self.temp_output_dir)

def setup_self_test_client(self, model, contrib_src_dir):
rest_api = ModelHubRESTAPI(model, self.contrib_src_dir)
rest_api.working_folder = self.temp_workdir
rest_api.working_folder = self.temp_work_dir
rest_api.api.output_folder = self.temp_output_dir
app = rest_api.app
app.config["TESTING"] = True
self.client = app.test_client()


#--------------------------------------------------------------------------
# Private helper/convenience functions
#--------------------------------------------------------------------------
Expand All @@ -102,4 +114,3 @@ def _post_predict_request_on_sample_image(self, sample_image_name):

if __name__ == '__main__':
unittest.main()

44 changes: 25 additions & 19 deletions framework/modelhubapi_tests/pythonapi_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import os
import numpy
import shutil
from modelhubapi import ModelHubAPI
from .apitestbase import TestAPIBase
from .mockmodel.contrib_src.inference import Model
Expand All @@ -16,9 +17,12 @@ def setUp(self):
self.this_dir = os.path.dirname(os.path.realpath(__file__))
contrib_src_dir = os.path.join(self.this_dir, "mockmodel", "contrib_src")
self.api = ModelHubAPI(model, contrib_src_dir)
self.setup_self_temp_output_dir()
self.api.output_folder = self.temp_output_dir


def tearDown(self):
shutil.rmtree(self.temp_output_dir, ignore_errors=True)
pass


Expand All @@ -30,17 +34,17 @@ def test_get_config_returns_no_error(self):
def test_get_config_returns_correct_dict(self):
config = self.api.get_config()
self.assert_config_contains_correct_dict(config)


def test_get_legal_returns_expected_keys(self):
legal = self.api.get_legal()
self.assert_legal_contains_expected_keys(legal)


def test_get_legal_returns_expected_mock_values(self):
legal = self.api.get_legal()
self.assert_legal_contains_expected_mock_values(legal)


def test_get_model_io_returns_expected_mock_values(self):
model_io = self.api.get_model_io()
Expand All @@ -52,28 +56,31 @@ def test_get_samples_returns_path_to_mock_samples(self):
self.assertEqual(self.this_dir + "/mockmodel/contrib_src/sample_data", samples["folder"])
samples["files"].sort()
self.assertListEqual(["testimage_ramp_4x2.jpg",
"testimage_ramp_4x2.png"],
"testimage_ramp_4x2.png"],
samples["files"])


def test_predict_returns_expected_mock_prediction(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png")
self.assert_predict_contains_expected_mock_prediction(result, expectNumpy=True)


def test_predict_returns_expected_mock_prediction_list(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png", numpyToFile=False)
self.assert_predict_contains_expected_mock_prediction(result, expectList=True)

def test_predict_returns_expected_mock_prediction_url(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png", numpyToFile=True)
self.assert_predict_contains_expected_mock_prediction(result)

def test_predict_returns_expected_mock_meta_info(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png")
self.assert_predict_contains_expected_mock_meta_info(result)


def test_predict_returns_correct_output_format(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png")
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png", numpyToFile=False)
self.assertIsInstance(result["output"], list)
self.assertIsInstance(result["output"][0]["prediction"], list)
self.assertIsInstance(result["output"][0]["prediction"][0], dict)
self.assertIsInstance(result["output"][0]["prediction"][1], dict)
self.assertIsInstance(result["output"][1]["prediction"], numpy.ndarray)
self.assertIsInstance(result["output"][1]["prediction"], list)


def test_predict_output_types_match_config(self):
model_io = self.api.get_model_io()
Expand All @@ -97,15 +104,15 @@ def tearDown(self):
pass


def test_predict_returns_expected_mock_prediction(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png")
self.assertListEqual([[0,1,1,0],[0,2,2,0]], result["output"][0]["prediction"].tolist())
def test_predict_returns_expected_mock_prediction_list(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png", numpyToFile=False)
self.assertListEqual([[0,1,1,0],[0,2,2,0]], result["output"][0]["prediction"])


def test_predict_returns_correct_output_format(self):
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png")
result = self.api.predict(self.this_dir + "/mockmodel/contrib_src/sample_data/testimage_ramp_4x2.png", numpyToFile=False)
self.assertIsInstance(result["output"], list)
self.assertIsInstance(result["output"][0]["prediction"], numpy.ndarray)
self.assertIsInstance(result["output"][0]["prediction"], list)



Expand Down Expand Up @@ -162,4 +169,3 @@ def setUp(self):

if __name__ == '__main__':
unittest.main()

3 changes: 1 addition & 2 deletions framework/modelhubapi_tests/pythonapivoid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def setUp(self):

def tearDown(self):
pass


def test_get_config_returns_error(self):
result = self.api.get_config()
Expand Down Expand Up @@ -49,4 +49,3 @@ def test_predict_returns_NotImplementedError(self):

if __name__ == '__main__':
unittest.main()

38 changes: 19 additions & 19 deletions framework/modelhubapi_tests/restapi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,36 @@ class TestModelHubRESTAPI(TestRESTAPIBase):
def setUp(self):
self.this_dir = os.path.dirname(os.path.realpath(__file__))
self.contrib_src_dir = os.path.join(self.this_dir, "mockmodel", "contrib_src")
self.setup_self_temp_workdir()
self.setup_self_temp_work_dir()
self.setup_self_temp_output_dir()
self.setup_self_test_client(Model(), self.contrib_src_dir)


def tearDown(self):
shutil.rmtree(self.temp_workdir, ignore_errors=True)
shutil.rmtree(self.temp_work_dir, ignore_errors=True)
shutil.rmtree(self.temp_output_dir, ignore_errors=True)
pass


def test_get_config_returns_correct_dict(self):
response = self.client.get("/api/get_config")
self.assertEqual(200, response.status_code)
config = json.loads(response.get_data())
self.assert_config_contains_correct_dict(config)


def test_get_legal_returns_expected_keys(self):
response = self.client.get("/api/get_legal")
self.assertEqual(200, response.status_code)
legal = json.loads(response.get_data())
self.assert_legal_contains_expected_keys(legal)


def test_get_legal_returns_expected_mock_values(self):
response = self.client.get("/api/get_legal")
self.assertEqual(200, response.status_code)
legal = json.loads(response.get_data())
self.assert_legal_contains_expected_mock_values(legal)


def test_get_model_io_returns_expected_mock_values(self):
response = self.client.get("/api/get_model_io")
Expand All @@ -56,21 +57,21 @@ def test_get_samples_returns_path_to_mock_samples(self):
samples = json.loads(response.get_data())
samples.sort()
self.assertListEqual(["http://localhost/api/samples/testimage_ramp_4x2.jpg",
"http://localhost/api/samples/testimage_ramp_4x2.png"],
"http://localhost/api/samples/testimage_ramp_4x2.png"],
samples)


def test_samples_routes_correct(self):
response = self.client.get("/api/samples/testimage_ramp_4x2.png")
self.assertEqual(200, response.status_code)
self.assertEqual("image/png", response.content_type)


def test_thumbnail_routes_correct(self):
response = self.client.get("/api/thumbnail/thumbnail.jpg")
self.assertEqual(200, response.status_code)
self.assertEqual("image/jpeg", response.content_type)


def test_get_model_files_returns_zip(self):
response = self.client.get("/api/get_model_files")
Expand All @@ -81,7 +82,7 @@ def test_get_model_files_returns_zip(self):
def test_get_model_files_returned_zip_has_mock_content(self):
response = self.client.get("/api/get_model_files")
self.assertEqual(200, response.status_code)
test_zip_name = self.temp_workdir + "/test_response.zip"
test_zip_name = self.temp_work_dir + "/test_response.zip"
with open(test_zip_name, "wb") as test_file:
test_file.write(response.get_data())
with ZipFile(test_zip_name, "r") as test_zip:
Expand All @@ -93,16 +94,16 @@ def test_get_model_files_returned_zip_has_mock_content(self):
zip_content = test_zip.namelist()
zip_content.sort()
self.assertListEqual(reference_content, zip_content)
self.assertEqual(b"EMPTY MOCK MODEL FOR UNIT TESTING",
self.assertEqual(b"EMPTY MOCK MODEL FOR UNIT TESTING",
test_zip.read("model/model.txt"))


def test_predict_by_post_returns_expected_mock_prediction(self):
response = self._post_predict_request_on_sample_image("testimage_ramp_4x2.png")
self.assertEqual(200, response.status_code)
result = json.loads(response.get_data())
self.assert_predict_contains_expected_mock_prediction(result)


def test_predict_by_post_returns_expected_mock_meta_info(self):
response = self._post_predict_request_on_sample_image("testimage_ramp_4x2.png")
Expand All @@ -122,7 +123,7 @@ def test_predict_by_post_returns_error_on_unsupported_file_type(self):
def test_working_folder_empty_after_predict_by_post(self):
response = self._post_predict_request_on_sample_image("testimage_ramp_4x2.png")
self.assertEqual(200, response.status_code)
self.assertEqual(len(os.listdir(self.temp_workdir) ), 0)
self.assertEqual(len(os.listdir(self.temp_work_dir) ), 0)


# TODO this is not so nice yet, test should not require a download from the inet
Expand All @@ -132,7 +133,7 @@ def test_predict_by_url_returns_expected_mock_prediction(self):
self.assertEqual(200, response.status_code)
result = json.loads(response.get_data())
self.assert_predict_contains_expected_mock_prediction(result)


# TODO this is not so nice yet, test should not require a download from the inet
# should probably use a mock server for this
Expand All @@ -156,7 +157,7 @@ def test_predict_by_url_returns_error_on_unsupported_file_type(self):
def test_working_folder_empty_after_predict_by_url(self):
response = self.client.get("/api/predict?fileurl=https://raw.githubusercontent.com/modelhub-ai/modelhub-docker/master/framework/modelhublib_tests/testdata/testimage_ramp_4x2.png")
self.assertEqual(200, response.status_code)
self.assertEqual(len(os.listdir(self.temp_workdir) ), 0)
self.assertEqual(len(os.listdir(self.temp_work_dir) ), 0)


def test_predict_sample_returns_expected_mock_prediction(self):
Expand All @@ -174,4 +175,3 @@ def test_predict_sample_on_invalid_file_returns_error(self):

if __name__ == '__main__':
unittest.main()

0 comments on commit de85c22

Please sign in to comment.