Skip to content

Commit

Permalink
Split interface into frontend and rest modules (#34)
Browse files Browse the repository at this point in the history
* factor helper functions into own class for better modularity

 When refactoring the main app into different flask blueprints I will have to have access to these methods from a central location. This static helper class will enable me to do this. The name is still quite stupid. I will have to think about it in more detail.

* refactor rest api into a flask blueprint

 This allows us to better separate logic from front-end in the future and easier extend the API and web front-end.
  • Loading branch information
jan-xyz authored and rhsimplex committed Aug 1, 2017
1 parent eb945a3 commit bda2c3d
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 188 deletions.
2 changes: 2 additions & 0 deletions picasso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from flask import Flask
import os
import sys
from picasso.interfaces.rest import API

if sys.version_info.major < 3 or (sys.version_info.major == 3 and
sys.version_info.minor < 5):
raise SystemError('Python 3.5+ required, found {}'.format(sys.version))

app = Flask(__name__)
app.config.from_object('picasso.config.Default')
app.register_blueprint(API, url_prefix='/api')

if os.getenv('PICASSO_SETTINGS'):
app.config.from_envvar('PICASSO_SETTINGS')
Expand Down
Empty file added picasso/interfaces/__init__.py
Empty file.
110 changes: 110 additions & 0 deletions picasso/interfaces/rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
"""Flask blueprint for accessing and manipulating image ressources
This is used by the main flask application to provide a REST API.
"""

import os
import shutil
from PIL import Image
from werkzeug.utils import secure_filename
from flask import (
Blueprint,
current_app,
jsonify,
session,
request
)
from picasso import __version__
from picasso.utils import get_visualizations

API = Blueprint('api', __name__)


@API.route('/', methods=['GET'])
def root():
"""The root of the REST API
displays a hello world message.
"""
return jsonify(message='Picasso {version}. '
'See API documentation at: '
'https://picasso.readthedocs.io/en/latest/api.html'
.format(version=__version__),
version=__version__)


@API.route('/images', methods=['POST', 'GET'])
def images():
"""Upload images via REST interface
Check if file upload was successful and sanatize user input.
TODO: return file URL instead of filename
"""
if request.method == 'POST':
file_upload = request.files['file']
if file_upload:
image = {}
image['filename'] = secure_filename(file_upload.filename)
full_path = os.path.join(session['img_input_dir'],
image['filename'])
file_upload.save(full_path)
image['uid'] = session['image_uid_counter']
session['image_uid_counter'] += 1
current_app.logger.debug('File %d is saved as %s',
image['uid'],
image['filename'])
session['image_list'].append(image)
return jsonify(ok="true", file=image['filename'], uid=image['uid'])
return jsonify(ok="false")
if request.method == 'GET':
return jsonify(images=session['image_list'])


@API.route('/visualize', methods=['GET'])
def visualize():
"""Trigger a visualization via the REST API
Takes a single image and generates the visualization data, returning the
output exactly as given by the target visualization.
"""

session['settings'] = {}
image_uid = request.args.get('image')
vis_name = request.args.get('visualizer')
vis = get_visualizations()[vis_name]
if hasattr(vis, 'settings'):
for key in vis.settings.keys():
if request.args.get(key) is not None:
session['settings'][key] = request.args.get(key)
else:
session['settings'][key] = vis.settings[key][0]
inputs = []
for image in session['image_list']:
if image['uid'] == int(image_uid):
full_path = os.path.join(session['img_input_dir'],
image['filename'])
entry = {}
entry['filename'] = image['filename']
entry['data'] = Image.open(full_path)
inputs.append(entry)
if 'settings' in session:
vis.update_settings(session['settings'])
output = vis.make_visualization(
inputs, output_dir=session['img_output_dir'])
return jsonify(output=output)


@API.route('/reset', methods=['GET'])
def reset():
"""Delete the session and clear temporary directories
"""
shutil.rmtree(session['img_input_dir'])
shutil.rmtree(session['img_output_dir'])
session.clear()
return jsonify(ok='true')
180 changes: 5 additions & 175 deletions picasso/picasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,60 +25,28 @@
Visualization classes available for rendering.
"""
from importlib import import_module
import inspect
import io
import os
from operator import itemgetter
import shutil
from tempfile import mkdtemp
import time
from types import ModuleType

from PIL import Image
from flask import (
g,
render_template,
request,
session,
send_from_directory,
jsonify
send_from_directory
)
from werkzeug.utils import secure_filename

from picasso import __version__
from picasso import app
from picasso.models.base import load_model
from picasso.visualizations import *
from picasso.visualizations.base import BaseVisualization

APP_TITLE = 'Picasso Visualizer'


def _get_visualization_classes():
"""Import visualizations classes dynamically
from picasso.utils import (
get_app_state,
get_visualizations
)

"""
visualization_attr = vars(
import_module('picasso.visualizations'))
visualization_submodules = [
visualization_attr[x]
for x in visualization_attr
if isinstance(visualization_attr[x], ModuleType)]

visualization_classes = []
for submodule in visualization_submodules:
attrs = vars(submodule)
for attr_name in attrs:
attr = attrs[attr_name]
if (inspect.isclass(attr)
and issubclass(attr, BaseVisualization)
and attr is not BaseVisualization):
visualization_classes.append(attr)
return visualization_classes


VISUALIZATION_CLASSES = _get_visualization_classes()

# Use a bogus secret key for debugging ease. No client information is stored;
# the secret key is only necessary for generating the session cookie.
Expand Down Expand Up @@ -117,144 +85,6 @@ def initialize_new_session():
session['img_output_dir'] = mkdtemp()


def get_model():
"""Get the NN model that's being analyzed from the request context. Put
the model in the request context if it is not yet there.
Returns:
instance of :class:`.models.model.Model` or derived
class
"""
if not hasattr(g, 'model'):
g.model = model
return g.model


def get_visualizations():
"""Get the available visualizations from the request context. Put the
visualizations in the request context if they are not yet there.
Returns:
:obj:`list` of instances of :class:`.BaseVisualization` or
derived class
"""
if not hasattr(g, 'visualizations'):
g.visualizations = {}
for VisClass in VISUALIZATION_CLASSES:
vis = VisClass(get_model())
g.visualizations[vis.__class__.__name__] = vis

return g.visualizations


def get_app_state():
"""Get current status of application in context
Returns:
:obj:`dict` of application status
"""
if not hasattr(g, 'app_state'):
model = get_model()
g.app_state = {
'app_title': APP_TITLE,
'model_name': type(model).__name__,
'latest_ckpt_name': model.latest_ckpt_name,
'latest_ckpt_time': model.latest_ckpt_time
}
return g.app_state


@app.route('/api/', methods=['GET'])
def api_root():
"""The root of the REST API
displays a hello world message.
"""
return jsonify(message='Picasso {version}. '
'See API documentation at: '
'https://picasso.readthedocs.io/en/latest/api.html'
.format(version=__version__),
version=__version__)


@app.route('/api/images', methods=['POST', 'GET'])
def api_images():
"""Upload images via REST interface
Check if file upload was successful and sanatize user input.
TODO: return file URL instead of filename
"""
if request.method == 'POST':
file_upload = request.files['file']
if file_upload:
image = {}
image['filename'] = secure_filename(file_upload.filename)
full_path = os.path.join(session['img_input_dir'],
image['filename'])
file_upload.save(full_path)
image['uid'] = session['image_uid_counter']
session['image_uid_counter'] += 1
app.logger.debug('File %d is saved as %s',
image['uid'],
image['filename'])
session['image_list'].append(image)
return jsonify(ok="true", file=image['filename'], uid=image['uid'])
return jsonify(ok="false")
if request.method == 'GET':
return jsonify(images=session['image_list'])


@app.route('/api/visualize', methods=['GET'])
def api_visualize():
"""Trigger a visualization via the REST API
Takes a single image and generates the visualization data, returning the
output exactly as given by the target visualization.
"""

session['settings'] = {}
image_uid = request.args.get('image')
vis_name = request.args.get('visualizer')
vis = get_visualizations()[vis_name]
if hasattr(vis, 'settings'):
for key in vis.settings.keys():
if request.args.get(key) is not None:
session['settings'][key] = request.args.get(key)
else:
session['settings'][key] = vis.settings[key][0]
inputs = []
for image in session['image_list']:
if image['uid'] == int(image_uid):
full_path = os.path.join(session['img_input_dir'],
image['filename'])
entry = {}
entry['filename'] = image['filename']
entry['data'] = Image.open(full_path)
inputs.append(entry)
if 'settings' in session:
vis.update_settings(session['settings'])
output = vis.make_visualization(inputs,
output_dir=session['img_output_dir'])
return jsonify(output=output)


@app.route('/api/reset', methods=['GET'])
def end_session():
"""Delete the session and clear temporary directories
"""
shutil.rmtree(session['img_input_dir'])
shutil.rmtree(session['img_output_dir'])
session.clear()
return jsonify(ok='true')


@app.route('/', methods=['GET', 'POST'])
def landing():
"""Landing page for the application
Expand Down

0 comments on commit bda2c3d

Please sign in to comment.