Skip to content

Commit

Permalink
Merge pull request #2 from ethantang95/Tim-Work
Browse files Browse the repository at this point in the history
Basic Tensorflow Support
  • Loading branch information
ethantang95 committed Jun 23, 2017
2 parents 1bce875 + ea25e1a commit 880198b
Show file tree
Hide file tree
Showing 75 changed files with 16,080 additions and 180 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Temporary files
*.swp
*~
.DS_Store
TAGS

# Compiled / optimized files
Expand Down
2 changes: 1 addition & 1 deletion .gjslintrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
--max_line_length=120
--exclude_directories=3rdparty
--exclude_directories=3rdparty,tb
--disable=0121,0220
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ env:
- DIGITS_TEST_FRAMEWORK=caffe CAFFE_FORK=NVIDIA
- DIGITS_TEST_FRAMEWORK=caffe CAFFE_FORK=BVLC
- DIGITS_TEST_FRAMEWORK=torch
- DIGITS_TEST_FRAMEWORK=tensorflow
- DIGITS_TEST_FRAMEWORK=none

matrix:
Expand Down Expand Up @@ -82,6 +83,7 @@ addons:
- cmake
- cython
- git
- gfortran
- graphviz
- libboost-filesystem-dev
- libboost-python-dev
Expand Down Expand Up @@ -128,6 +130,7 @@ install:
- echo "backend:agg" > ~/.config/matplotlib/matplotlibrc
- ./scripts/travis/install-caffe.sh $CAFFE_ROOT
- if [ "$DIGITS_TEST_FRAMEWORK" == "torch" ]; then travis_wait ./scripts/travis/install-torch.sh $TORCH_ROOT; else unset TORCH_ROOT; fi
- if [ "$DIGITS_TEST_FRAMEWORK" == "tensorflow" ]; then travis_wait ./scripts/travis/install-tensorflow.sh; fi
- pip install -r ./requirements.txt
- pip install -r ./requirements_test.txt
- pip install -e .
Expand Down
2 changes: 2 additions & 0 deletions digits/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
torch,
server_name,
store_option,
tensorflow,
)


Expand All @@ -20,3 +21,4 @@ def config_value(option):
Return the current configuration value for the given option
"""
return option_list[option]

46 changes: 46 additions & 0 deletions digits/config/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os
import platform
from subprocess import Popen,PIPE

from . import option_list

VARNAME_ENV_TFPY = 'TENSORFLOW_PYTHON'
DEFAULT_PYTHON_EXE = 'python2' # @TODO(tzaman) - use the python executable that was used to launch digits?

if platform.system() == 'Darwin':
# DYLD_LIBRARY_PATH and LD_LIBRARY_PATH is sometimes stripped, and the cuda libraries might need it
if not "DYLD_LIBRARY_PATH" in os.environ:
if "CUDA_HOME" in os.environ:
os.environ["DYLD_LIBRARY_PATH"] = str(os.environ["CUDA_HOME"] + '/lib')

def test_tf_import(python_exe):
"""
Tests if tensorflow can be imported, returns if it went okay and optional error.
"""
p = Popen([python_exe, "-c", "import tensorflow"], stdout=PIPE, stderr=PIPE)
(out, err) = p.communicate()
return p.returncode==0, str(err)

if VARNAME_ENV_TFPY in os.environ:
tf_python_exe = os.environ[VARNAME_ENV_TFPY]
else:
tf_python_exe = DEFAULT_PYTHON_EXE

tf_enabled, err = test_tf_import(tf_python_exe)

if not tf_enabled:
print('Tensorflow support disabled.')
# print('Failed importing Tensorflow with python executable "%s"\n%s' % (tf_python_exe, err))

if tf_enabled:
option_list['tensorflow'] = {
'enabled': True,
'executable': tf_python_exe,
}
else:
option_list['tensorflow'] = {
'enabled': False,
}
3 changes: 3 additions & 0 deletions digits/dataset/images/classification/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ class ImageClassificationDatasetForm(ImageDatasetForm):
choices=[
('lmdb', 'LMDB'),
('hdf5', 'HDF5'),
('tfrecords', 'TFRecords'),
],
default='lmdb',
)

def validate_backend(form, field):
if field.data == 'lmdb':
form.compression.data = 'none'
elif field.data == 'tfrecords':
form.compression.data = 'none'
elif field.data == 'hdf5':
form.encoding.data = 'none'

Expand Down
3 changes: 2 additions & 1 deletion digits/dataset/tasks/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def __init__(self, input_file, db_name, backend, image_dims, **kwargs):
self.input_file = input_file
self.db_name = db_name
self.backend = backend
if backend == 'hdf5':
if backend == 'hdf5' or backend == 'tfrecords':
# the list of hdf5 files is stored in a textfile
# tfrecords can be sharded as well
self.textfile = os.path.join(self.db_name, 'list.txt')
self.image_dims = image_dims
if image_dims[2] == 3:
Expand Down
6 changes: 6 additions & 0 deletions digits/extensions/view/imageOutput/config_template.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
{{ form.channel_order(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.data_order])}}">
{{ form.data_order.label }}
{{ form.data_order.tooltip }}
{{ form.data_order(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.pixel_conversion])}}">
{{ form.pixel_conversion.label }}
{{ form.pixel_conversion.tooltip }}
Expand Down
12 changes: 12 additions & 0 deletions digits/extensions/view/imageOutput/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ class ConfigForm(Form):
'is ignored in the case of a grayscale image)'
)

data_order = utils.forms.SelectField(
'Data order',
choices=[
('chw', 'CHW'),
('hwc', 'HWC'),
],
default='chw',
tooltip="Set the order of the data. For Caffe and Torch models this "
"is often CHW, for Tensorflow it's HWC."
"W=Width, H=Height, C=Channels"
)

pixel_conversion = utils.forms.SelectField(
'Pixel conversion',
choices=[
Expand Down
8 changes: 7 additions & 1 deletion digits/extensions/view/imageOutput/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, dataset, **kwargs):

# view options
self.channel_order = kwargs['channel_order'].upper()
self.data_order = kwargs['data_order'].upper()
self.normalize = (kwargs['pixel_conversion'] == 'normalize')

@staticmethod
Expand Down Expand Up @@ -76,8 +77,13 @@ def process_data(self, input_id, input_data, output_data):
"""
Process one inference and return data to visualize
"""
# assume the only output is a CHW image

data = output_data[output_data.keys()[0]].astype('float32')

if self.data_order == 'HWC':
data = (data.transpose((2, 0, 1)))

# assume CHW at this point
channels = data.shape[0]
if channels == 3 and self.channel_order == 'BGR':
data = data[[2, 1, 0], ...] # BGR to RGB
Expand Down
6 changes: 6 additions & 0 deletions digits/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .caffe_framework import CaffeFramework
from .framework import Framework
from .tensorflow_framework import TensorflowFramework
from .torch_framework import TorchFramework
from digits.config import config_value

Expand All @@ -19,6 +20,9 @@
# torch is optional
torch = TorchFramework() if config_value('torch')['enabled'] else None

# tensorflow is optional
tensorflow = TensorflowFramework() if config_value('tensorflow')['enabled'] else None

# caffe is mandatory
caffe = CaffeFramework()

Expand All @@ -35,6 +39,8 @@ def get_frameworks():
frameworks = [caffe]
if torch:
frameworks.append(torch)
if tensorflow:
frameworks.append(tensorflow)
return frameworks


Expand Down
5 changes: 4 additions & 1 deletion digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class CaffeFramework(Framework):

# whether this framework can shuffle data during training
CAN_SHUFFLE_DATA = False
SUPPORTS_PYTHON_LAYERS_FILE = True
SUPPORTS_TIMELINE_TRACING = False

if config_value('caffe')['flavor'] == 'NVIDIA':
if parse_version(config_value('caffe')['version']) > parse_version('0.14.0-alpha'):
Expand Down Expand Up @@ -132,10 +134,11 @@ def get_network_from_path(self, path):
return network

@override
def get_network_visualization(self, desc):
def get_network_visualization(self, **kwargs):
"""
return visualization of network
"""
desc = kwargs['desc']
net = caffe_pb2.NetParameter()
text_format.Merge(desc, net)
# Throws an error if name is None
Expand Down
14 changes: 13 additions & 1 deletion digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def can_shuffle_data(self):
"""
return self.CAN_SHUFFLE_DATA

def supports_python_layers_file(self):
"""
return whether framework can shuffle input data during training
"""
return self.SUPPORTS_PYTHON_LAYERS_FILE

def supports_timeline_traces(self):
"""
return whether framework supports creating timeline traces
"""
return self.SUPPORTS_TIMELINE_TRACING

def supports_solver_type(self, solver_type):
"""
return whether framework supports this solver_type
Expand Down Expand Up @@ -77,7 +89,7 @@ def get_network_from_path(self, path):
"""
raise NotImplementedError('Please implement me')

def get_network_visualization(self, desc):
def get_network_visualization(self, **kwargs):
"""
return visualization of network
"""
Expand Down

0 comments on commit 880198b

Please sign in to comment.