Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GAN support for DIGITS #3

Merged
merged 49 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
6c81a79
bAbI data plug-in
gheinrich Oct 23, 2016
9372b8e
Tensorflow integration updates
gheinrich Nov 4, 2016
c5b397e
Add gradient hook
gheinrich Dec 11, 2016
d15ae65
Add memn2n model
gheinrich Dec 11, 2016
a0a262f
Update memn2n with gradient hooks
gheinrich Dec 20, 2016
8cdd4ff
GAN example
gheinrich Jan 16, 2017
ec108d1
Make batch size variable
gheinrich Jan 16, 2017
630e440
Training/inference paths
gheinrich Jan 16, 2017
3279f31
Small update to TF 0.12
gheinrich Jan 17, 2017
dc5c261
Snapshot names, float inference, restore all vars
gheinrich Jan 17, 2017
bf0fe27
Do not restore global_step or optimizer variables
gheinrich Jan 21, 2017
87bd27e
Add TB link
gheinrich Jan 20, 2017
5704612
Update GAN network
gheinrich Jan 21, 2017
0d4c07c
Dynamically select inference form
gheinrich Jan 21, 2017
f31a8e5
TF inference: convert images to float
gheinrich Jan 21, 2017
4766bc2
Update GAN z-gen network
gheinrich Jan 22, 2017
f09a7ac
Small Update model view layout
gheinrich Jan 22, 2017
b1dc6dd
Add GAN plug-ins
gheinrich Jan 22, 2017
1992060
Update GAN plug-in to create CelebA dataset
gheinrich Jan 26, 2017
af3fc4b
Add ability to show input in ImageOutput extension
gheinrich Feb 1, 2017
bdc6f99
Add all data to raw data view extension
gheinrich Feb 4, 2017
7f39aba
Add model for CelebA dataset
gheinrich Feb 4, 2017
d6c4a2f
Update GAN data plug-in
gheinrich Feb 4, 2017
9baa279
Update all losses in one session
gheinrich Feb 4, 2017
661c2d7
Remove conversion to .png in GAN data plug-in
gheinrich Feb 8, 2017
abf7691
TF Slim Lenet example
gheinrich Feb 13, 2017
810f694
Update GAN data plug-in
gheinrich Feb 13, 2017
8bcd27a
Fix TF model snapshot
gheinrich Feb 13, 2017
a15e8e9
Reduce scheduler delays to speed up inference
gheinrich Feb 13, 2017
890a459
Update GAN plugins
gheinrich Feb 14, 2017
a1fd74b
Fix TF tests
gheinrich Feb 14, 2017
2a34bae
Add API to LmdbReader (used by gan_features.py)
gheinrich Feb 14, 2017
1df6097
Save animated gif
gheinrich Feb 14, 2017
0d90fb3
Add GAN walk-through
gheinrich Feb 14, 2017
8ea8a5b
Update GAN walkthrough with embeddings video
gheinrich Feb 15, 2017
cc08a8d
Fix GAN view for list encoding
gheinrich Feb 16, 2017
7941479
Add animation task to GAN plugins
gheinrich Feb 24, 2017
df5cfd9
Add view task to see image attributes
gheinrich Mar 2, 2017
eed9ba4
Add comments to GAN models
gheinrich Mar 14, 2017
0f61aee
Update README
gheinrich Mar 15, 2017
9fc5f0b
Fix GAN features script
gheinrich Mar 16, 2017
195abe9
GAN app
gheinrich Mar 21, 2017
75d2621
Fix DIGITS inference
gheinrich Mar 23, 2017
485e96a
Adjust GAN window size automatically
gheinrich Mar 23, 2017
ba2fc56
Add attributes to GAN app
gheinrich Mar 23, 2017
91d808d
Move gandisplay.py
gheinrich Mar 23, 2017
9a0d968
Remove wxpython 3.0 selection
gheinrich Mar 23, 2017
b645cd1
Fix call to model
gheinrich Mar 23, 2017
4326b01
Adding disclaimer
gheinrich Apr 5, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion digits/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ def config_value(option):
Return the current configuration value for the given option
"""
return option_list[option]

9 changes: 5 additions & 4 deletions digits/config/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,27 @@

import os
import platform
from subprocess import Popen,PIPE
from subprocess import Popen, PIPE

from . import option_list

VARNAME_ENV_TFPY = 'TENSORFLOW_PYTHON'
DEFAULT_PYTHON_EXE = 'python2' # @TODO(tzaman) - use the python executable that was used to launch digits?
DEFAULT_PYTHON_EXE = 'python2' # @TODO(tzaman) - use the python executable that was used to launch digits?

if platform.system() == 'Darwin':
# DYLD_LIBRARY_PATH and LD_LIBRARY_PATH is sometimes stripped, and the cuda libraries might need it
if not "DYLD_LIBRARY_PATH" in os.environ:
if "DYLD_LIBRARY_PATH" not in os.environ:
if "CUDA_HOME" in os.environ:
os.environ["DYLD_LIBRARY_PATH"] = str(os.environ["CUDA_HOME"] + '/lib')


def test_tf_import(python_exe):
"""
Tests if tensorflow can be imported, returns if it went okay and optional error.
"""
p = Popen([python_exe, "-c", "import tensorflow"], stdout=PIPE, stderr=PIPE)
(out, err) = p.communicate()
return p.returncode==0, str(err)
return p.returncode == 0, str(err)

if VARNAME_ENV_TFPY in os.environ:
tf_python_exe = os.environ[VARNAME_ENV_TFPY]
Expand Down
29 changes: 29 additions & 0 deletions digits/dataset/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import images as dataset_images
from . import generic
from digits import extensions
from digits.utils.routing import job_from_request, request_wants_json
from digits.webapp import scheduler

Expand Down Expand Up @@ -54,3 +55,31 @@ def summary():
return generic.views.summary(job)
else:
raise werkzeug.exceptions.BadRequest('Invalid job type')


@blueprint.route('/inference-form/<extension_id>/<job_id>', methods=['GET'])
def inference_form(extension_id, job_id):
"""
Returns a rendering of an inference form
"""
inference_form_html = ""

if extension_id != "all-default":
extension_class = extensions.data.get_extension(extension_id)
if not extension_class:
raise RuntimeError("Unable to find data extension with ID=%s"
% job_id.dataset.extension_id)
job = scheduler.get_job(job_id)
if hasattr(job, 'extension_userdata'):
extension_userdata = job.extension_userdata
else:
extension_userdata = {}
extension_userdata.update({'is_inference_db': True})
extension = extension_class(**extension_userdata)

form = extension.get_inference_form()
if form:
template, context = extension.get_inference_template(form)
inference_form_html = flask.render_template_string(template, **context)

return inference_form_html
6 changes: 6 additions & 0 deletions digits/extensions/view/imageOutput/config_template.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@
{{ form.pixel_conversion.tooltip }}
{{ form.pixel_conversion(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.show_input])}}">
{{ form.show_input.label }}
{{ form.show_input.tooltip }}
{{ form.show_input(class='form-control') }}
</div>
10 changes: 10 additions & 0 deletions digits/extensions/view/imageOutput/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,13 @@ class ConfigForm(Form):
tooltip='Select method to convert pixel values to the target bit '
'range'
)

show_input = utils.forms.SelectField(
'Show input as image',
choices=[
('yes', 'Yes'),
('no', 'No'),
],
default='no',
tooltip='Show input as image'
)
18 changes: 15 additions & 3 deletions digits/extensions/view/imageOutput/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, dataset, **kwargs):
self.channel_order = kwargs['channel_order'].upper()
self.data_order = kwargs['data_order'].upper()
self.normalize = (kwargs['pixel_conversion'] == 'normalize')
self.show_input = (kwargs['show_input'] == 'yes')

@staticmethod
def get_config_form():
Expand Down Expand Up @@ -70,17 +71,28 @@ def get_view_template(self, data):
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'image': digits.utils.image.embed_image_html(data)}
return self.view_template, {'image_input': digits.utils.image.embed_image_html(data[0]),
'image_output': digits.utils.image.embed_image_html(data[1])}

@override
def process_data(self, input_id, input_data, output_data):
"""
Process one inference and return data to visualize
"""

data = output_data[output_data.keys()[0]].astype('float32')
if self.show_input:
data_input = input_data.astype('float32')
image_input = self.process_image(self.data_order, data_input)
else:
image_input = None

data_output = output_data[output_data.keys()[0]].astype('float32')
image_output = self.process_image(self.data_order, data_output)

return [image_input, image_output]

if self.data_order == 'HWC':
def process_image(self, data_order, data):
if data_order == 'HWC':
data = (data.transpose((2, 0, 1)))

# assume CHW at this point
Expand Down
5 changes: 4 additions & 1 deletion digits/extensions/view/imageOutput/view_template.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

<img src="{{image}}" style="max-width:100%;" />
{% if image_input %}
<img src="{{image_input}}" style="max-width:100%;" />
{% endif %}
<img src="{{image_output}}" style="max-width:100%;" />
5 changes: 4 additions & 1 deletion digits/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from .caffe_framework import CaffeFramework
from .framework import Framework
from .tensorflow_framework import TensorflowFramework
from .torch_framework import TorchFramework
from digits.config import config_value

Expand All @@ -13,6 +12,10 @@
'TorchFramework',
]

if config_value('tensorflow')['enabled']:
from .tensorflow_framework import TensorflowFramework
__all__.append('TensorflowFramework')

#
# create framework instances
#
Expand Down
3 changes: 1 addition & 2 deletions digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def can_accumulate_gradients(self):
if config_value('caffe')['flavor'] == 'BVLC':
return True
elif config_value('caffe')['flavor'] == 'NVIDIA':
return (parse_version(config_value('caffe')['version'])
> parse_version('0.14.0-alpha'))
return (parse_version(config_value('caffe')['version']) > parse_version('0.14.0-alpha'))
else:
raise ValueError('Unknown flavor. Support NVIDIA and BVLC flavors only.')
30 changes: 13 additions & 17 deletions digits/frameworks/tensorflow_framework.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import numpy as np
import os
import re
import subprocess
import time
import tempfile

import flask

from .errors import Error, NetworkVisualizationError, BadNetworkError
from .errors import NetworkVisualizationError
from .framework import Framework
import digits
from digits import utils
from digits.config import config_value
from digits.model.tasks import TensorflowTrainTask
from digits.utils import subclass, override, constants


@subclass
class TensorflowFramework(Framework):
"""
Expand All @@ -35,7 +32,7 @@ class TensorflowFramework(Framework):
SUPPORTS_PYTHON_LAYERS_FILE = False
SUPPORTS_TIMELINE_TRACING = True

SUPPORTED_SOLVER_TYPES = ['SGD','ADADELTA','ADAGRAD','ADAGRADDA','MOMENTUM','ADAM','FTRL','RMSPROP']
SUPPORTED_SOLVER_TYPES = ['SGD', 'ADADELTA', 'ADAGRAD', 'ADAGRADDA', 'MOMENTUM', 'ADAM', 'FTRL', 'RMSPROP']

SUPPORTED_DATA_TRANSFORMATION_TYPES = ['MEAN_SUBTRACTION', 'CROPPING']
SUPPORTED_DATA_AUGMENTATION_TYPES = ['FLIPPING', 'NOISE', 'CONTRAST', 'WHITENING', 'HSV_SHIFTING']
Expand All @@ -50,7 +47,7 @@ def create_train_task(self, **kwargs):
"""
create train task
"""
return TensorflowTrainTask(framework_id = self.framework_id, **kwargs)
return TensorflowTrainTask(framework_id=self.framework_id, **kwargs)

@override
def get_standard_network_desc(self, network):
Expand Down Expand Up @@ -126,10 +123,10 @@ def get_network_visualization(self, **kwargs):
# Another for the HTML
_, temp_html_path = tempfile.mkstemp(suffix='.html')

try: # do this in a try..finally clause to make sure we delete the temp file
try: # do this in a try..finally clause to make sure we delete the temp file
# build command line
args = [config_value('tensorflow')['executable'],
os.path.join(os.path.dirname(digits.__file__),'tools','tensorflow','main.py'),
os.path.join(os.path.dirname(digits.__file__), 'tools', 'tensorflow', 'main.py'),
'--network=%s' % os.path.basename(temp_network_path),
'--networkDirectory=%s' % os.path.dirname(temp_network_path),
'--visualizeModelPath=%s' % temp_graphdef_path,
Expand All @@ -141,7 +138,7 @@ def get_network_visualization(self, **kwargs):

if use_mean and use_mean != 'none':
mean_file = dataset.get_mean_file()
assert mean_file != None, 'Failed to retrieve mean file.'
assert mean_file is not None, 'Failed to retrieve mean file.'
args.append('--subtractMean=%s' % use_mean)
args.append('--mean=%s' % dataset.path(mean_file))

Expand All @@ -163,15 +160,14 @@ def get_network_visualization(self, **kwargs):

env = os.environ.copy()
# make only a selected number of GPUs visible. The ID is not important for just the vis
env['CUDA_VISIBLE_DEVICES'] = ",".join([str(i) for i in range(0,int(num_gpus))])
env['CUDA_VISIBLE_DEVICES'] = ",".join([str(i) for i in range(0, int(num_gpus))])

# execute command
p = subprocess.Popen(args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True,
env=env
)
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True,
env=env)

stdout_log = ''
while p.poll() is None:
Expand All @@ -181,7 +177,7 @@ def get_network_visualization(self, **kwargs):
stdout_log += line
if p.returncode:
raise NetworkVisualizationError(stdout_log)
else: # Success!
else: # Success!
return repr(str(open(temp_graphdef_path).read()))
finally:
os.remove(temp_network_path)
Expand Down
24 changes: 13 additions & 11 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,14 @@ def validate_py_ext(form, field):
tooltip="How many epochs of training between running through one pass of the validation data?"
)

traces_interval = utils.forms.IntegerField('Tracing Interval (in steps)',
validators=[
validators.NumberRange(min=0)
],
default=0,
tooltip="Generation of a timeline trace every few steps"
)
traces_interval = utils.forms.IntegerField(
'Tracing Interval (in steps)',
validators=[
validators.NumberRange(min=0)
],
default=0,
tooltip="Generation of a timeline trace every few steps"
)

random_seed = utils.forms.IntegerField(
'Random seed',
Expand Down Expand Up @@ -311,10 +312,11 @@ def validate_lr_multistep_values(form, field):
)

def validate_custom_network_snapshot(form, field):
if form.method.data == 'custom':
for filename in field.data.strip().split(os.path.pathsep):
if filename and not os.path.exists(filename):
raise validators.ValidationError('File "%s" does not exist' % filename)
pass
#if form.method.data == 'custom':
# for filename in field.data.strip().split(os.path.pathsep):
# if filename and not os.path.exists(filename):
# raise validators.ValidationError('File "%s" does not exist' % filename)

# Select one of several GPUs
select_gpu = wtforms.RadioField(
Expand Down
2 changes: 1 addition & 1 deletion digits/model/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = task.get_snapshot(epoch)
snapshot_filename = task.get_snapshot(epoch, download=True)

# get model files
model_files = task.get_model_files()
Expand Down