Skip to content

Commit

Permalink
Improved layout/readability of code
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelschwier committed Sep 4, 2018
1 parent 7358d8a commit af77322
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
6 changes: 6 additions & 0 deletions framework/modelhubapi/pythonapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, model, contrib_src_dir):
this_dir = os.path.dirname(os.path.realpath(__file__))
self.framework_dir = os.path.normpath(os.path.join(this_dir, ".."))


def get_config(self):
"""
Returns:
Expand All @@ -24,6 +25,7 @@ def get_config(self):
config_file_path = self.contrib_src_dir + "/model/config.json"
return self._load_json(config_file_path)


def get_legal(self):
"""
Returns:
Expand All @@ -45,6 +47,7 @@ def get_legal(self):
legal.update(self._load_txt_as_dict(contrib_license_dir + "/sample_data", "sample_data_license"))
return legal


def get_model_io(self):
"""
Returns:
Expand Down Expand Up @@ -79,6 +82,7 @@ def get_samples(self):
except Exception as e:
return {'error': repr(e)}


def predict(self, input_file_path, numpyToList = False):
"""
Preforms the model's inference on the given input.
Expand Down Expand Up @@ -134,6 +138,7 @@ def _load_txt_as_dict(self, file_path, return_key):
except Exception as e:
return {'error': str(e)}


def _load_json(self, file_path):
try:
with io.open(file_path, mode='r', encoding='utf-8') as f:
Expand All @@ -142,6 +147,7 @@ def _load_json(self, file_path):
except Exception as e:
return {'error': str(e)}


def _correct_output_list_wrapping(self, output):
if not isinstance(output, list):
return [output]
Expand Down
14 changes: 13 additions & 1 deletion framework/modelhubapi/restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, model, contrib_src_dir):
self.app.add_url_rule('/api/predict_sample', 'predict_sample',
self.predict_sample)


def get_config(self):
"""
GET method
Expand All @@ -47,6 +48,7 @@ def get_config(self):
"""
return self._jsonify(self.api.get_config())


def get_legal(self):
"""
GET method
Expand All @@ -65,6 +67,7 @@ def get_legal(self):
"""
return self._jsonify(self.api.get_legal())


def get_model_io(self):
"""
GET method
Expand All @@ -77,6 +80,7 @@ def get_model_io(self):
"""
return self._jsonify(self.api.get_model_io())


def get_model_files(self):
"""
GET method
Expand All @@ -100,6 +104,7 @@ def get_model_files(self):
except Exception as e:
return self._jsonify({'error': str(e)})


def get_samples(self):
"""
GET method
Expand All @@ -117,6 +122,7 @@ def get_samples(self):
except Exception as e:
return self._jsonify({'error': str(e)})


def predict(self):
"""
GET/POST method
Expand Down Expand Up @@ -178,6 +184,7 @@ def predict(self):
except Exception as e:
return self._jsonify({'error': str(e)})


def predict_sample(self):
"""
GET method
Expand Down Expand Up @@ -209,12 +216,14 @@ def predict_sample(self):
except Exception as e:
return self._jsonify({'error': str(e)})


def start(self):
"""
Starts the flask app.
"""
self.app.run(host='0.0.0.0', port=80, threaded=True)


# -------------------------------------------------------------------------
# Private helper functions
# -------------------------------------------------------------------------
Expand All @@ -240,16 +249,17 @@ def _samples(self, sample_name):
"""
Routing function for sample files that exist in contrib_src.
"""

return send_from_directory(self.contrib_src_dir + "/sample_data/", sample_name, cache_timeout=-1)


def _thumbnail(self, thumbnail_name):
"""
Routing function for the thumbnail that exists in contrib_src. The
thumbnail must be named "thumbnail.jpg".
"""
return send_from_directory(self.contrib_src_dir + "/model/", thumbnail_name, cache_timeout=-1)


def _get_file_name(self, mime_type):
"""
This utility function get the current date/time and returns a full path
Expand All @@ -261,5 +271,7 @@ def _get_file_name(self, mime_type):
mime_type.split("/")[1]))
return file_name


def _get_allowed_extensions(self):
return self.api.get_model_io()["input"]["format"]

0 comments on commit af77322

Please sign in to comment.