Skip to content

Commit

Permalink
Merge pull request NVIDIA#756 from gheinrich/dev/view-extensions
Browse files Browse the repository at this point in the history
View extension framework
  • Loading branch information
lukeyeager committed May 25, 2016
2 parents d505a86 + 7d8ff7b commit 29063d8
Show file tree
Hide file tree
Showing 15 changed files with 481 additions and 83 deletions.
1 change: 1 addition & 0 deletions digits/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from __future__ import absolute_import

from .data import *
from .view import *
35 changes: 35 additions & 0 deletions digits/extensions/view/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from . import rawData

view_extensions = [
# set show=True if extension should be listed in known extensions
{'class': rawData.Visualization, 'show': True},
]


def get_default_extension():
"""
return the default view extension
"""
return rawData.Visualization


def get_extensions():
"""
return set of data data extensions
"""
return [extension['class'] for extension
in view_extensions if extension['show']]


def get_extension(extension_id):
"""
return extension associated with specified extension_id
"""
for extension in view_extensions:
extension_class = extension['class']
if extension_class.get_id() == extension_id:
return extension_class
return None
92 changes: 92 additions & 0 deletions digits/extensions/view/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import


class VisualizationInterface(object):
"""
A visualization extension
"""

def __init__(self, **kwargs):
pass

@staticmethod
def get_config_form():
"""
Return a form to be used to configure visualization options
"""
raise NotImplementedError

@staticmethod
def get_config_template(form):
"""
The config template shows a form with view config options
Parameters:
- form: form returned by get_config_form(). This may be populated
with values if the job was cloned
Returns:
- (template, context) tuple
- template is a Jinja template to use for rendering config options
- context is a dictionary of context variables to use for rendering
the form
"""
raise NotImplementedError

@staticmethod
def get_id():
"""
Returns a unique ID
"""
raise NotImplementedError

def get_summary_template(self):
"""
This returns a summary of the job. This method is called after all
entries have been processed.
Returns:
- (template, context) tuple
- template is a Jinja template to use for rendering the summary,
or None if there is no summary to display
- context is a dictionary of context variables to use for rendering
the form
"""
return None, None

@staticmethod
def get_title():
"""
Returns a title
"""
raise NotImplementedError

def get_view_template(self, data):
"""
The view template shows the visualization of one inference output
Parameters:
- data: the data returned by process_data()
Returns:
- (template, context) tuple
- template is a Jinja template to use for rendering config options
- context is a dictionary of context variables to use for rendering
the form
"""
raise NotImplementedError

def process_data(
self,
dataset,
input_data,
inference_data,
ground_truth=None):
"""
Process one inference output
Parameters:
- dataset: dataset used during training
- input_data: input to the network
- inference_data: network output
- ground_truth: Ground truth. Format is application specific.
None if absent.
Returns:
- an object reprensenting the processed data
"""
raise NotImplementedError
4 changes: 4 additions & 0 deletions digits/extensions/view/rawData/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .view import Visualization
7 changes: 7 additions & 0 deletions digits/extensions/view/rawData/config_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

<small>This visualization has no configuration options</small>
10 changes: 10 additions & 0 deletions digits/extensions/view/rawData/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from digits.utils import subclass
from flask.ext.wtf import Form


@subclass
class ConfigForm(Form):
pass
85 changes: 85 additions & 0 deletions digits/extensions/view/rawData/view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os

from digits.utils import subclass, override
from .forms import ConfigForm
from ..interface import VisualizationInterface

CONFIG_TEMPLATE = "config_template.html"
VIEW_TEMPLATE = "view_template.html"


@subclass
class Visualization(VisualizationInterface):
"""
A visualization extension to display raw data
"""

def __init__(self, dataset, **kwargs):
# memorize view template for later use
extension_dir = os.path.dirname(os.path.abspath(__file__))
self.view_template = open(
os.path.join(extension_dir, VIEW_TEMPLATE), "r").read()

@staticmethod
def get_config_form():
return ConfigForm()

@staticmethod
def get_config_template(form):
"""
parameters:
- form: form returned by get_config_form(). This may be populated
with values if the job was cloned
return:
- (template, context) tuple
- template is a Jinja template to use for rendering config options
- context is a dictionary of context variables to use for rendering
the form
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(
os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read()
return (template, {})

@staticmethod
def get_id():
return 'all-raw-data'

@staticmethod
def get_title():
return 'Raw Data'

@override
def get_view_template(self, data):
"""
return:
- (template, context) tuple
- template is a Jinja template to use for rendering config options
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'data': data}

@override
def process_data(
self,
dataset,
input_data,
inference_data,
ground_truth=None):
"""
Process one inference output
Parameters:
- dataset: dataset used during training
- input_data: input to the network
- inference_data: network output
- ground_truth: Ground truth. Format is application specific.
None if absent.
Returns:
- an object reprensenting the processed data
"""
# just return the same data and ignore ground truth
return inference_data
10 changes: 10 additions & 0 deletions digits/extensions/view/rawData/view_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

{% for key, val in data.iteritems() %}
<td>{{key}}</td>
<td>{{data[key]}}</td>
{% endfor %}
14 changes: 13 additions & 1 deletion digits/model/images/generic/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import PIL.Image
from urlparse import urlparse

from digits import extensions
from digits.config import config_value
import digits.dataset.images.generic.test_views
import digits.test_views
Expand Down Expand Up @@ -240,6 +241,11 @@ def test_page_model_new(self):
def test_nonexistent_model(self):
assert not self.model_exists('foo'), "model shouldn't exist"

def test_view_config(self):
extension = extensions.view.get_default_extension()
rv = self.app.get('/models/view-config/%s' % extension.get_id())
assert rv.status_code == 200, 'page load failed with %s' % rv.status_code

def test_visualize_network(self):
rv = self.app.post('/models/visualize-network?framework='+self.FRAMEWORK,
data = {'custom_network': self.network()}
Expand Down Expand Up @@ -595,9 +601,15 @@ def test_infer_many_from_folder(self):
# StringIO wrapping is needed to simulate POST file upload.
file_upload = (StringIO(textfile_images), 'images.txt')

# try selecting the extension explicitly
extension = extensions.view.get_default_extension()
extension_id = extension.get_id()

rv = self.app.post(
'/models/images/generic/infer_many?job_id=%s' % self.model_id,
data = {'image_list': file_upload, 'image_folder': os.path.dirname(self.test_image)}
data = {'image_list': file_upload,
'image_folder': os.path.dirname(self.test_image),
'view_extension_id': extension_id}
)
s = BeautifulSoup(rv.data, 'html.parser')
body = s.select('body')
Expand Down
Loading

0 comments on commit 29063d8

Please sign in to comment.