Skip to content

Commit

Permalink
Merge pull request NVIDIA#961 from gheinrich/dev/semantic-segmentation
Browse files Browse the repository at this point in the history
Semantic segmentation work-flow
  • Loading branch information
lukeyeager committed Aug 30, 2016
2 parents 2529523 + 82af74a commit 3188c00
Show file tree
Hide file tree
Showing 54 changed files with 1,676 additions and 6 deletions.
31 changes: 26 additions & 5 deletions digits/dataset/generic/views.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import caffe_pb2
import flask
import PIL.Image

# Find the best implementation available
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO

import caffe_pb2
import flask
import matplotlib as mpl
import numpy as np
import PIL.Image

from .forms import GenericDatasetForm
from .job import GenericDatasetJob

from digits import extensions, utils
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from digits.utils.routing import request_wants_json, job_from_request
from digits.utils.lmdbreader import DbReader
from digits.webapp import scheduler
Expand Down Expand Up @@ -145,6 +147,18 @@ def explore():
db_path = job.path(db)
labels = []

if COLOR_PALETTE_ATTRIBUTE in job.extension_userdata:
# assume single-channel 8-bit palette
palette = job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]
palette = np.array(palette).reshape((len(palette)/3,3)) / 255.
# normalize input pixels to [0,1]
norm = mpl.colors.Normalize(vmin=0,vmax=255)
# create map
cmap = mpl.pyplot.cm.ScalarMappable(norm=norm,
cmap=mpl.colors.ListedColormap(palette))
else:
cmap = None

page = int(flask.request.args.get('page', 0))
size = int(flask.request.args.get('size', 25))

Expand All @@ -167,6 +181,13 @@ def explore():
s.write(datum.data)
s.seek(0)
img = PIL.Image.open(s)
if cmap and img.mode in ['L', '1']:
data = np.array(img)
data = cmap.to_rgba(data)*255
data = data.astype('uint8')
# keep RGB values only, remove alpha channel
data = data[:, :, 0:3]
img = PIL.Image.fromarray(data)
imgs.append({"label": None, "b64": utils.image.embed_image_html(img)})
count += 1
if len(imgs) >= size:
Expand Down
2 changes: 2 additions & 0 deletions digits/extensions/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import imageGradients
from . import imageProcessing
from . import imageSegmentation
from . import objectDetection

data_extensions = [
Expand All @@ -11,6 +12,7 @@
# editing DIGITS config option 'data_extension_list'
{'class': imageGradients.DataIngestion, 'show': False},
{'class': imageProcessing.DataIngestion, 'show': True},
{'class': imageSegmentation.DataIngestion, 'show': True},
{'class': objectDetection.DataIngestion, 'show': True},
]

Expand Down
4 changes: 4 additions & 0 deletions digits/extensions/data/imageSegmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .data import DataIngestion
203 changes: 203 additions & 0 deletions digits/extensions/data/imageSegmentation/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import math
import os
import random

import numpy as np
import PIL.Image

from digits.utils import image, subclass, override, constants
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from ..interface import DataIngestionInterface
from .forms import DatasetForm

TEMPLATE = "template.html"


@subclass
class DataIngestion(DataIngestionInterface):
"""
A data ingestion extension for an image segmentation dataset
"""

def __init__(self, **kwargs):
"""
the parent __init__ method automatically populates this
instance with attributes from the form
"""
super(DataIngestion, self).__init__(**kwargs)

self.random_indices = None

if not 'seed' in self.userdata:
# choose random seed and add to userdata so it gets persisted
self.userdata['seed'] = random.randint(0, 1000)

random.seed(self.userdata['seed'])

# open first image in label folder to retrieve palette
# all label images must use the same palette - this is enforced
# during dataset creation
filename = self.make_image_list(self.label_folder)[0]
image = self.load_label(filename)
self.userdata[COLOR_PALETTE_ATTRIBUTE] = image.getpalette()

# get labels if those were provided
if self.class_labels_file:
with open(self.class_labels_file) as f:
self.userdata['class_labels'] = f.read().splitlines()

@override
def encode_entry(self, entry):
"""
Return numpy.ndarray
"""
feature_image_file = entry[0]
label_image_file = entry[1]

# feature image
feature_image = self.encode_PIL_Image(
image.load_image(feature_image_file),
self.channel_conversion)

# label image
label_image = self.load_label(label_image_file)
if label_image.getpalette() != self.userdata[COLOR_PALETTE_ATTRIBUTE]:
raise ValueError("All label images must use the same palette")
label_image = self.encode_PIL_Image(label_image)

return feature_image, label_image

def encode_PIL_Image(self, image, channel_conversion='none'):
if channel_conversion != 'none':
if image.mode != channel_conversion:
# convert to different image mode if necessary
image = image.convert(channel_conversion)
# convert to numpy array
image = np.array(image)
# add channel axis if input is grayscale image
if image.ndim == 2:
image = image[..., np.newaxis]
elif image.ndim != 3:
raise ValueError("Unhandled number of channels: %d" % image.ndim)
# transpose to CHW
image = image.transpose(2, 0, 1)
return image

@staticmethod
@override
def get_category():
return "Images"

@staticmethod
@override
def get_id():
return "image-segmentation"

@staticmethod
@override
def get_dataset_form():
return DatasetForm()

@staticmethod
@override
def get_dataset_template(form):
"""
parameters:
- form: form returned by get_dataset_form(). This may be populated
with values if the job was cloned
returns:
- (template, context) tuple
- template is a Jinja template to use for rendering dataset creation
options
- context is a dictionary of context variables to use for rendering
the form
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(os.path.join(extension_dir, TEMPLATE), "r").read()
context = {'form': form}
return (template, context)

@staticmethod
@override
def get_title():
return "Segmentation"

@override
def itemize_entries(self, stage):
if stage == constants.TEST_DB:
# don't retun anything for the test stage
return []

if stage == constants.TRAIN_DB or (not self.has_val_folder):
feature_image_list = self.make_image_list(self.feature_folder)
label_image_list = self.make_image_list(self.label_folder)
else:
# separate validation images
feature_image_list = self.make_image_list(self.validation_feature_folder)
label_image_list = self.make_image_list(self.validation_label_folder)

# make sure filenames match
if len(feature_image_list) != len(label_image_list):
raise ValueError(
"Expect same number of images in feature and label folders (%d!=%d)"
% (len(feature_image_list), len(label_image_list)))

for idx in range(len(feature_image_list)):
feature_name = os.path.splitext(
os.path.split(feature_image_list[idx])[1])[0]
label_name = os.path.splitext(
os.path.split(label_image_list[idx])[1])[0]
if feature_name != label_name:
raise ValueError("No corresponding feature/label pair found for (%s,%s)"
% (feature_name, label_name) )

# split lists if there is no val folder
if not self.has_val_folder:
feature_image_list = self.split_image_list(feature_image_list, stage)
label_image_list = self.split_image_list(label_image_list, stage)

return zip(
feature_image_list,
label_image_list)

def load_label(self, filename):
"""
Load a label image
"""
image = PIL.Image.open(filename)
if image.mode not in ['P', 'L', '1']:
raise ValueError("Labels are expected to be single-channel (paletted or "
" grayscale) images - %s mode is '%s'"
% (filename, image.mode))
return image

def make_image_list(self, folder):
image_files = []
for dirpath, dirnames, filenames in os.walk(folder, followlinks=True):
for filename in filenames:
if filename.lower().endswith(image.SUPPORTED_EXTENSIONS):
image_files.append('%s' % os.path.join(folder, filename))
if len(image_files) == 0:
raise ValueError("Unable to find supported images in %s" % folder)
return sorted(image_files)

def split_image_list(self, filelist, stage):
if self.random_indices is None:
self.random_indices = range(len(filelist))
random.shuffle(self.random_indices)
elif len(filelist) != len(self.random_indices):
raise ValueError(
"Expect same number of images in folders (%d!=%d)"
% (len(filelist), len(self.random_indices)))
filelist = [filelist[idx] for idx in self.random_indices]
pct_val = int(self.folder_pct_val)
n_val_entries = int(math.floor(len(filelist) * pct_val / 100))
if stage == constants.VAL_DB:
return filelist[:n_val_entries]
elif stage == constants.TRAIN_DB:
return filelist[n_val_entries:]
else:
raise ValueError("Unknown stage: %s" % stage)
Loading

0 comments on commit 3188c00

Please sign in to comment.