Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FCN ObjectSegmentation with PyTorch backend #2041

Merged
merged 9 commits into from
Apr 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 86 additions & 25 deletions jsk_perception/node_scripts/fcn_object_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import chainer
from chainer import cuda
import chainer.serializers as S
import fcn
import fcn.models

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
Expand All @@ -19,45 +14,82 @@
from sensor_msgs.msg import Image


def softmax(w, t=1.0):
e = np.exp(w)
dist = e / np.sum(e, axis=0)
return dist
is_torch_available = True
try:
import torch
except ImportError:
is_torch_available = False


def assert_torch_available():
if not is_torch_available:
url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl' # NOQA
raise RuntimeError('Please install pytorch: pip install %s' % url)


class FCNObjectSegmentation(ConnectionBasedTransport):

def __init__(self):
super(self.__class__, self).__init__()
self.backend = rospy.get_param('~backend', 'chainer')
self.gpu = rospy.get_param('~gpu', -1) # -1 is cpu mode
self.target_names = rospy.get_param('~target_names')
self.bg_label = rospy.get_param('~bg_label', 0)
self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
self._load_model()
self.pub = self.advertise('~output', Image, queue_size=1)
self.pub_proba = self.advertise(
'~output/proba_image', Image, queue_size=1)

def _load_model(self):
self.gpu = rospy.get_param('~gpu', -1) # -1 is cpu mode
self.model_name = rospy.get_param('~model_name')
model_h5 = rospy.get_param('~model_h5')
self.target_names = rospy.get_param('~target_names')
self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
if self.backend == 'chainer':
self._load_chainer_model()
elif self.backend == 'torch':
assert_torch_available()
self._load_torch_model()
else:
raise RuntimeError('Unsupported backend: %s', self.backend)

def _load_chainer_model(self):
model_name = rospy.get_param('~model_name')
model_h5 = rospy.get_param('~model_h5')
n_class = len(self.target_names)
if self.model_name == 'fcn32s':
if model_name == 'fcn32s':
self.model = fcn.models.FCN32s(n_class=n_class)
elif self.model_name == 'fcn16s':
elif model_name == 'fcn16s':
self.model = fcn.models.FCN16s(n_class=n_class)
elif self.model_name == 'fcn8s':
elif model_name == 'fcn8s':
self.model = fcn.models.FCN8s(n_class=n_class)
else:
rospy.logerr('Unsupported ~model_name: {0}'
.format(self.model_name))
raise ValueError('Unsupported ~model_name: {}'.format(model_name))
jsk_loginfo('Loading trained model: {0}'.format(model_h5))
S.load_hdf5(model_h5, self.model)
jsk_loginfo('Finished loading trained model: {0}'.format(model_h5))
if self.gpu != -1:
self.model.to_gpu(self.gpu)
self.model.train = False

def _load_torch_model(self):
try:
import torchfcn
except ImportError as e:
raise ImportError('Please install torchfcn: pip install torchfcn')
n_class = len(self.target_names)
model_file = rospy.get_param('~model_file')
model_name = rospy.get_param('~model_name')
if model_name == 'fcn32s':
self.model = torchfcn.models.FCN32s(n_class=n_class)
elif model_name == 'fcn32s_bilinear':
self.model = torchfcn.models.FCN32s(n_class=n_class, deconv=False)
else:
raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
jsk_loginfo('Loading trained model: %s' % model_file)
self.model.load_state_dict(torch.load(model_file))
jsk_loginfo('Finished loading trained model: %s' % model_file)
if self.gpu >= 0:
self.model = self.model.cuda(self.gpu)
self.model.eval()

def subscribe(self):
use_mask = rospy.get_param('~use_mask', False)
Expand Down Expand Up @@ -114,22 +146,51 @@ def _cb(self, img_msg):
self.pub_proba.publish(proba_msg)

def segment(self, bgr):
if self.backend == 'chainer':
return self._segment_chainer_backend(bgr)
elif self.backend == 'torch':
assert_torch_available()
return self._segment_torch_backend(bgr)
raise ValueError('Unsupported backend: {0}'.format(self.backend))

def _segment_chainer_backend(self, bgr):
blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
x_data = np.array([blob], dtype=np.float32)
if self.gpu != -1:
x_data = cuda.to_gpu(x_data, device=self.gpu)
x = chainer.Variable(x_data, volatile=True)
self.model.train = False
self.model(x)
pred = self.model.score
score = cuda.to_cpu(pred.data)[0]
proba_img = softmax(score).transpose((1, 2, 0))
max_proba_img = np.max(proba_img, axis=-1)
label = np.argmax(score, axis=0)
proba_img = chainer.functions.softmax(self.model.score)[0]
proba_img = chainer.functions.transpose(proba_img, (1, 2, 0))
max_proba_img = chainer.functions.max(proba_img, axis=-1)
label = chainer.functions.argmax(self.model.score, axis=1)
# gpu -> cpu
proba_img = cuda.to_cpu(proba_img.data)[0]
max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
label = cuda.to_cpu(label.data)[0]
# uncertain because the probability is low
label[max_proba_img < self.proba_threshold] = self.bg_label
return label, proba_img

def _segment_torch_backend(self, bgr):
blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
x_data = np.array([blob], dtype=np.float32)
x_data = torch.from_numpy(x_data)
if self.gpu >= 0:
x_data = x_data.cuda(self.gpu)
x = torch.autograd.Variable(x_data, volatile=True)
score = self.model(x)
proba = torch.nn.functional.softmax(score)
max_proba, label = torch.max(proba, 1)
# uncertain because the probability is low
label[max_proba < self.proba_threshold] = self.bg_label
# gpu -> cpu
score = score.permute(0, 2, 3, 1).data.cpu().numpy()[0]
proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
max_proba = max_proba.data.cpu().numpy().squeeze((0, 1))
label = label.data.cpu().numpy().squeeze((0, 1))
return label, proba


if __name__ == '__main__':
rospy.init_node('fcn_object_segmentation')
Expand Down
1 change: 1 addition & 0 deletions jsk_perception/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
<run_depend>python-fcn-pip</run_depend> <!-- pip -->
<!-- }} install fcn -->
<run_depend>python-sklearn</run_depend>
<run_depend>python-torch-cuda80-pip</run_depend>
<run_depend>rosbag</run_depend>
<run_depend>roscpp</run_depend>
<run_depend>roseus</run_depend>
Expand Down