From c31b58276fb6a717f7286d1d11be288180327081 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Thu, 12 May 2016 14:40:18 +0200 Subject: [PATCH 1/2] View extensions --- digits/extensions/__init__.py | 1 + digits/extensions/view/__init__.py | 35 +++++ digits/extensions/view/interface.py | 92 +++++++++++++ digits/extensions/view/rawData/__init__.py | 4 + .../view/rawData/config_template.html | 7 + digits/extensions/view/rawData/forms.py | 10 ++ digits/extensions/view/rawData/view.py | 85 ++++++++++++ .../view/rawData/view_template.html | 10 ++ digits/model/images/generic/views.py | 130 ++++++++++++++---- digits/model/views.py | 20 ++- .../models/images/generic/infer_db.html | 48 ++++--- .../models/images/generic/infer_many.html | 46 ++++--- .../models/images/generic/infer_one.html | 25 ++-- .../templates/models/images/generic/show.html | 37 +++++ 14 files changed, 468 insertions(+), 82 deletions(-) create mode 100644 digits/extensions/view/__init__.py create mode 100644 digits/extensions/view/interface.py create mode 100644 digits/extensions/view/rawData/__init__.py create mode 100644 digits/extensions/view/rawData/config_template.html create mode 100644 digits/extensions/view/rawData/forms.py create mode 100644 digits/extensions/view/rawData/view.py create mode 100644 digits/extensions/view/rawData/view_template.html diff --git a/digits/extensions/__init__.py b/digits/extensions/__init__.py index 65b8532bf..a3860f58d 100644 --- a/digits/extensions/__init__.py +++ b/digits/extensions/__init__.py @@ -2,3 +2,4 @@ from __future__ import absolute_import from .data import * +from .view import * diff --git a/digits/extensions/view/__init__.py b/digits/extensions/view/__init__.py new file mode 100644 index 000000000..a5ebb3573 --- /dev/null +++ b/digits/extensions/view/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from . import rawData + +view_extensions = [ + # set show=True if extension should be listed in known extensions + {'class': rawData.Visualization, 'show': True}, +] + + +def get_default_extension(): + """ + return the default view extension + """ + return rawData.Visualization + + +def get_extensions(): + """ + return set of data data extensions + """ + return [extension['class'] for extension + in view_extensions if extension['show']] + + +def get_extension(extension_id): + """ + return extension associated with specified extension_id + """ + for extension in view_extensions: + extension_class = extension['class'] + if extension_class.get_id() == extension_id: + return extension_class + return None diff --git a/digits/extensions/view/interface.py b/digits/extensions/view/interface.py new file mode 100644 index 000000000..9d86c18be --- /dev/null +++ b/digits/extensions/view/interface.py @@ -0,0 +1,92 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + + +class VisualizationInterface(object): + """ + A visualization extension + """ + + def __init__(self, **kwargs): + pass + + @staticmethod + def get_config_form(): + """ + Return a form to be used to configure visualization options + """ + raise NotImplementedError + + @staticmethod + def get_config_template(form): + """ + The config template shows a form with view config options + Parameters: + - form: form returned by get_config_form(). This may be populated + with values if the job was cloned + Returns: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + raise NotImplementedError + + @staticmethod + def get_id(): + """ + Returns a unique ID + """ + raise NotImplementedError + + def get_summary_template(self): + """ + This returns a summary of the job. This method is called after all + entries have been processed. + Returns: + - (template, context) tuple + - template is a Jinja template to use for rendering the summary, + or None if there is no summary to display + - context is a dictionary of context variables to use for rendering + the form + """ + return None, None + + @staticmethod + def get_title(): + """ + Returns a title + """ + raise NotImplementedError + + def get_view_template(self, data): + """ + The view template shows the visualization of one inference output + Parameters: + - data: the data returned by process_data() + Returns: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + raise NotImplementedError + + def process_data( + self, + dataset, + input_data, + inference_data, + ground_truth=None): + """ + Process one inference output + Parameters: + - dataset: dataset used during training + - input_data: input to the network + - inference_data: network output + - ground_truth: Ground truth. Format is application specific. + None if absent. + Returns: + - an object reprensenting the processed data + """ + raise NotImplementedError diff --git a/digits/extensions/view/rawData/__init__.py b/digits/extensions/view/rawData/__init__.py new file mode 100644 index 000000000..c0b2b8cf3 --- /dev/null +++ b/digits/extensions/view/rawData/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from .view import Visualization diff --git a/digits/extensions/view/rawData/config_template.html b/digits/extensions/view/rawData/config_template.html new file mode 100644 index 000000000..b84e40954 --- /dev/null +++ b/digits/extensions/view/rawData/config_template.html @@ -0,0 +1,7 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + +{% from "helper.html" import print_flashes %} +{% from "helper.html" import print_errors %} +{% from "helper.html" import mark_errors %} + +This visualization has no configuration options diff --git a/digits/extensions/view/rawData/forms.py b/digits/extensions/view/rawData/forms.py new file mode 100644 index 000000000..d4489c217 --- /dev/null +++ b/digits/extensions/view/rawData/forms.py @@ -0,0 +1,10 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from digits.utils import subclass +from flask.ext.wtf import Form + + +@subclass +class ConfigForm(Form): + pass diff --git a/digits/extensions/view/rawData/view.py b/digits/extensions/view/rawData/view.py new file mode 100644 index 000000000..8647b8c27 --- /dev/null +++ b/digits/extensions/view/rawData/view.py @@ -0,0 +1,85 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +import os + +from digits.utils import subclass, override +from .forms import ConfigForm +from ..interface import VisualizationInterface + +CONFIG_TEMPLATE = "config_template.html" +VIEW_TEMPLATE = "view_template.html" + + +@subclass +class Visualization(VisualizationInterface): + """ + A visualization extension to display raw data + """ + + def __init__(self, dataset, **kwargs): + # memorize view template for later use + extension_dir = os.path.dirname(os.path.abspath(__file__)) + self.view_template = open( + os.path.join(extension_dir, VIEW_TEMPLATE), "r").read() + + @staticmethod + def get_config_form(): + return ConfigForm() + + @staticmethod + def get_config_template(form): + """ + parameters: + - form: form returned by get_config_form(). This may be populated + with values if the job was cloned + return: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + extension_dir = os.path.dirname(os.path.abspath(__file__)) + template = open( + os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read() + return (template, {}) + + @staticmethod + def get_id(): + return 'all-raw-data' + + @staticmethod + def get_title(): + return 'Raw Data' + + @override + def get_view_template(self, data): + """ + return: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + return self.view_template, {'data': data} + + @override + def process_data( + self, + dataset, + input_data, + inference_data, + ground_truth=None): + """ + Process one inference output + Parameters: + - dataset: dataset used during training + - input_data: input to the network + - inference_data: network output + - ground_truth: Ground truth. Format is application specific. + None if absent. + Returns: + - an object reprensenting the processed data + """ + # just return the same data and ignore ground truth + return inference_data diff --git a/digits/extensions/view/rawData/view_template.html b/digits/extensions/view/rawData/view_template.html new file mode 100644 index 000000000..cf010623e --- /dev/null +++ b/digits/extensions/view/rawData/view_template.html @@ -0,0 +1,10 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + +{% from "helper.html" import print_flashes %} +{% from "helper.html" import print_errors %} +{% from "helper.html" import mark_errors %} + +{% for key, val in data.iteritems() %} + {{key}} + {{data[key]}} +{% endfor %} diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py index 9d631f175..3f40281d4 100644 --- a/digits/model/images/generic/views.py +++ b/digits/model/images/generic/views.py @@ -257,11 +257,18 @@ def create(extension_id=None): # If there are multiple jobs launched, go to the home page. return flask.redirect('/') + def show(job, related_jobs=None): """ Called from digits.model.views.models_show() """ - return flask.render_template('models/images/generic/show.html', job=job, related_jobs=related_jobs) + view_extensions = get_view_extensions() + return flask.render_template( + 'models/images/generic/show.html', + job=job, + view_extensions=view_extensions, + related_jobs=related_jobs) + @blueprint.route('/large_graph', methods=['GET']) def large_graph(): @@ -317,7 +324,7 @@ def infer_one(): inference_job.wait_completion() # retrieve inference data - inputs, outputs, visualizations = inference_job.get_data() + inputs, outputs, model_visualization = inference_job.get_data() # set return status code status_code = 500 if inference_job.status == 'E' else 200 @@ -329,20 +336,30 @@ def infer_one(): os.remove(image_path) image = None + inference_view_html = None if inputs is not None and len(inputs['data']) == 1: image = utils.image.embed_image_html(inputs['data'][0]) + visualizations, summary = get_inference_visualizations( + model_job.dataset, + inputs, + outputs) + inference_view_html = visualizations[0] if request_wants_json(): - return flask.jsonify({'outputs': dict((name, blob.tolist()) for name,blob in outputs.iteritems())}), status_code + return flask.jsonify({'outputs': dict((name, blob.tolist()) + for name, blob in outputs.iteritems())}), status_code else: - return flask.render_template('models/images/generic/infer_one.html', - model_job = model_job, - job = inference_job, - image_src = image, - network_outputs = outputs, - visualizations = visualizations, - total_parameters= sum(v['param_count'] for v in visualizations if v['vis_type'] == 'Weights'), - ), status_code + return flask.render_template( + 'models/images/generic/infer_one.html', + model_job=model_job, + job=inference_job, + image_src=image, + inference_view_html=inference_view_html, + visualizations=model_visualization, + total_parameters=sum(v['param_count'] for v in model_visualization + if v['vis_type'] == 'Weights'), + ), status_code + @blueprint.route('/infer_db.json', methods=['POST']) @blueprint.route('/infer_db', methods=['POST', 'GET']) @@ -393,8 +410,13 @@ def infer_db(): # an error occurred outputs = None + inference_views_html = None if inputs is not None: keys = [str(idx) for idx in inputs['ids']] + inference_views_html, summary_html = get_inference_visualizations( + model_job.dataset, + inputs, + outputs) else: keys = None @@ -404,12 +426,15 @@ def infer_db(): result[key] = dict((name, blob[i].tolist()) for name,blob in outputs.iteritems()) return flask.jsonify({'outputs': result}), status_code else: - return flask.render_template('models/images/generic/infer_db.html', - model_job = model_job, - job = inference_job, - keys = keys, - network_outputs = outputs, - ), status_code + return flask.render_template( + 'models/images/generic/infer_db.html', + model_job=model_job, + job=inference_job, + keys=keys, + inference_views_html=inference_views_html, + summary_html=summary_html, + ), status_code + @blueprint.route('/infer_many.json', methods=['POST']) @blueprint.route('/infer_many', methods=['POST', 'GET']) @@ -490,21 +515,28 @@ def infer_many(): # an error occurred outputs = None + inference_views_html = None if inputs is not None: paths = [paths[idx] for idx in inputs['ids']] + inference_views_html, summary_html = get_inference_visualizations( + model_job.dataset, + inputs, + outputs) if request_wants_json(): result = {} for i, path in enumerate(paths): - result[path] = dict((name, blob[i].tolist()) for name,blob in outputs.iteritems()) + result[path] = dict((name, blob[i].tolist()) for name, blob in outputs.iteritems()) return flask.jsonify({'outputs': result}), status_code else: - return flask.render_template('models/images/generic/infer_many.html', - model_job = model_job, - job = inference_job, - paths = paths, - network_outputs = outputs, - ), status_code + return flask.render_template( + 'models/images/generic/infer_many.html', + model_job=model_job, + job=inference_job, + paths=paths, + inference_views_html=inference_views_html, + summary_html=summary_html, + ), status_code def get_datasets(extension_id): @@ -521,6 +553,46 @@ def get_datasets(extension_id): for j in sorted(jobs, cmp=lambda x, y: cmp(y.id(), x.id()))] +def get_inference_visualizations(dataset, inputs, outputs): + # get extension ID from form and retrieve extension class + if 'view_extension_id' in flask.request.form: + view_extension_id = flask.request.form['view_extension_id'] + extension_class = extensions.view.get_extension(view_extension_id) + if extension_class is None: + raise ValueError("Unknown extension '%s'" % view_extension_id) + else: + # no view extension specified, use default + extension_class = extensions.view.get_default_extension() + extension_form = extension_class.get_config_form() + + # validate form + extension_form_valid = extension_form.validate_on_submit() + if not extension_form_valid: + raise ValueError("Extension form validation failed with %s" % repr(extension_form.errors)) + + # create instance of extension class + extension = extension_class(dataset, **extension_form.data) + + visualizations = [] + # process data + n = len(inputs['ids']) + for idx in xrange(n): + input_id = inputs['ids'][idx] + input_data = inputs['data'][idx] + output_data = {key: outputs[key][idx] for key in outputs} + data = extension.process_data( + input_id, + input_data, + output_data) + template, context = extension.get_view_template(data) + visualizations.append( + flask.render_template_string(template, **context)) + # get summary + template, context = extension.get_summary_template() + summary = flask.render_template_string(template, **context) if template else None + return visualizations, summary + + def get_previous_networks(): return [(j.id(), j.name()) for j in sorted( [j for j in scheduler.jobs.values() if isinstance(j, GenericImageModelJob)], @@ -546,3 +618,13 @@ def get_previous_network_snapshots(): prev_network_snapshots.append(e) return prev_network_snapshots + +def get_view_extensions(): + """ + return all enabled view extensions + """ + view_extensions = {} + all_extensions = extensions.view.get_extensions() + for extension in all_extensions: + view_extensions[extension.get_id()] = extension.get_title() + return view_extensions diff --git a/digits/model/views.py b/digits/model/views.py index 3ef854ec7..f5eb4dc1b 100644 --- a/digits/model/views.py +++ b/digits/model/views.py @@ -11,14 +11,12 @@ import flask import werkzeug.exceptions -from . import forms from . import images as model_images from . import ModelJob -import digits -from digits import frameworks +from digits import frameworks, extensions from digits.utils import time_filters from digits.utils.routing import request_wants_json -from digits.webapp import app, scheduler +from digits.webapp import scheduler blueprint = flask.Blueprint(__name__, __name__) @@ -142,6 +140,20 @@ def customize(): 'snapshot': snapshot }) + +@blueprint.route('/view-config/', methods=['GET']) +def view_config(extension_id): + """ + Returns a rendering of a view extension configuration template + """ + extension = extensions.view.get_extension(extension_id) + if extension is None: + raise ValueError("Unknown extension '%s'" % extension_id) + config_form = extension.get_config_form() + template, context = extension.get_config_template(config_form) + return flask.render_template_string(template, **context) + + @blueprint.route('/visualize-network', methods=['POST']) def visualize_network(): """ diff --git a/digits/templates/models/images/generic/infer_db.html b/digits/templates/models/images/generic/infer_db.html index 52e667291..25f6b384e 100644 --- a/digits/templates/models/images/generic/infer_db.html +++ b/digits/templates/models/images/generic/infer_db.html @@ -17,7 +17,7 @@

-{% if not network_outputs %} +{% if not keys %}

Inference failed, see job log

@@ -27,26 +27,34 @@

{% block job_content_details %} -{% if network_outputs %} - - - - - {% for key in network_outputs.keys() %} - - {% endfor %} - - {% for key in keys %} - - - - {% set index=loop.index0 %} - {% for key, val in network_outputs.iteritems() %} - +{% if summary_html %} +
+

Summary

+ {{summary_html|safe}} +
+{% endif %} + +{% if inference_views_html %} +
+

Visualizations

+
IndexKey{{key}}
{{loop.index}}{{key}}{{network_outputs[key][index]}}
+ + + + + + {% for key in keys %} + + + + + {% endfor %} - - {% endfor %} -
IndexKeyData
{{loop.index}}{{key}} + {% set index=loop.index0 %} + {{ inference_views_html[index]|safe }} +
+ + {% endif %} {% endblock %} diff --git a/digits/templates/models/images/generic/infer_many.html b/digits/templates/models/images/generic/infer_many.html index 6ff5a1ccc..cf0a5d38d 100644 --- a/digits/templates/models/images/generic/infer_many.html +++ b/digits/templates/models/images/generic/infer_many.html @@ -17,7 +17,7 @@

-{% if not network_outputs %} +{% if not paths %}

Inference failed, see job log

@@ -27,26 +27,32 @@

{% block job_content_details %} -{% if network_outputs %} - - - - - {% for key in network_outputs.keys() %} - - {% endfor %} - - {% for path in paths %} - - - - {% set index=loop.index0 %} - {% for key, val in network_outputs.iteritems() %} - +{% if summary_html %} +
+

Summary

+ {{summary_html|safe}} +
+{% endif %} + +{% if paths %} +
+

Visualizations

+
Image{{key}}
{{loop.index}}{{path}}{{network_outputs[key][index]}}
+ + + + + + {% for path in paths %} + + + + {% set index=loop.index0 %} + + {% endfor %} - - {% endfor %} -
ImageData
{{loop.index}}{{path}}{{inference_views_html[index]|safe}}
+ + {% endif %} {% endblock %} diff --git a/digits/templates/models/images/generic/infer_one.html b/digits/templates/models/images/generic/infer_one.html index b656563b1..39fabde6d 100644 --- a/digits/templates/models/images/generic/infer_one.html +++ b/digits/templates/models/images/generic/infer_one.html @@ -19,25 +19,22 @@

- +{% if image_src %} +
+

Source image

+ +
+
+

Inference visualization

+ {{ inference_view_html|safe }} +
+{% else %}
- {% if image_src %} -
- -
- -
- {% for name, blob in network_outputs.iteritems() %} -

{{name}}

- {{blob}} - {% endfor %} -
- {% else %}

Inference failed, see job log

- {% endif %}
+{% endif %} {% endblock %} diff --git a/digits/templates/models/images/generic/show.html b/digits/templates/models/images/generic/show.html index 2de9a1206..49dd64550 100644 --- a/digits/templates/models/images/generic/show.html +++ b/digits/templates/models/images/generic/show.html @@ -144,6 +144,43 @@

Trained Models

+
+
+

Select Visualization Method

+
+ +
+
+
+

Visualization Options

+
+
+
+ +
+

Test a single image

From 7d8ff7b6fbe7621a7cc8af002616e727b2f03be5 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Thu, 19 May 2016 09:28:03 +0200 Subject: [PATCH 2/2] Add tests for view extensions --- digits/model/images/generic/test_views.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/digits/model/images/generic/test_views.py b/digits/model/images/generic/test_views.py index ccfb3becf..88aa32e9c 100644 --- a/digits/model/images/generic/test_views.py +++ b/digits/model/images/generic/test_views.py @@ -23,6 +23,7 @@ import PIL.Image from urlparse import urlparse +from digits import extensions from digits.config import config_value import digits.dataset.images.generic.test_views import digits.test_views @@ -240,6 +241,11 @@ def test_page_model_new(self): def test_nonexistent_model(self): assert not self.model_exists('foo'), "model shouldn't exist" + def test_view_config(self): + extension = extensions.view.get_default_extension() + rv = self.app.get('/models/view-config/%s' % extension.get_id()) + assert rv.status_code == 200, 'page load failed with %s' % rv.status_code + def test_visualize_network(self): rv = self.app.post('/models/visualize-network?framework='+self.FRAMEWORK, data = {'custom_network': self.network()} @@ -595,9 +601,15 @@ def test_infer_many_from_folder(self): # StringIO wrapping is needed to simulate POST file upload. file_upload = (StringIO(textfile_images), 'images.txt') + # try selecting the extension explicitly + extension = extensions.view.get_default_extension() + extension_id = extension.get_id() + rv = self.app.post( '/models/images/generic/infer_many?job_id=%s' % self.model_id, - data = {'image_list': file_upload, 'image_folder': os.path.dirname(self.test_image)} + data = {'image_list': file_upload, + 'image_folder': os.path.dirname(self.test_image), + 'view_extension_id': extension_id} ) s = BeautifulSoup(rv.data, 'html.parser') body = s.select('body')