-
Notifications
You must be signed in to change notification settings - Fork 306
Add Cityscapes semantic segmentation dataset #392
Changes from 17 commits
5488d03
5173ad1
6f12fa3
7a7661b
e838475
45e1a3c
5abfee7
62960b7
2deef29
4fc0d78
b6e9f99
07d6b8f
13b66c7
77e0aa7
e4fee67
d50abf8
5b28c0e
0740b5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import glob | ||
import os | ||
|
||
import numpy as np | ||
|
||
from chainer import dataset | ||
from chainer.dataset import download | ||
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_labels | ||
from chainercv.utils import read_image | ||
|
||
|
||
class CityscapesSemanticSegmentationDataset(dataset.DatasetMixin):

    """Dataset class for a semantic segmentation task on `Cityscapes dataset`_.

    .. _`Cityscapes dataset`: https://www.cityscapes-dataset.com

    .. note::

        Please manually download the data because it is not allowed to
        re-distribute Cityscapes dataset.

    Args:
        data_dir (string): Path to the dataset directory. The directory should
            contain at least two directories, :obj:`leftImg8bit` and either
            :obj:`gtFine` or :obj:`gtCoarse`. If :obj:`None` is given, it uses
            :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cityscapes` by default.
        label_resolution ({'fine', 'coarse'}): The resolution of the
            labels. It should be either :obj:`fine` or :obj:`coarse`.
        split ({'train', 'val'}): Select from dataset splits used in
            Cityscapes dataset. Every directory directly under the label
            directory whose name contains this string is used.
        ignore_labels (bool): If True, the labels marked :obj:`ignoreInEval`
            defined in the original
            `cityscapesScripts <https://github.com/mcordts/cityscapesScripts>`_
            will be replaced with :obj:`-1` in the :meth:`get_example` method.
            The default value is :obj:`True`.

    Raises:
        ValueError: If :obj:`label_resolution` is neither :obj:`fine` nor
            :obj:`coarse`, or if the image or label directory does not
            exist under :obj:`data_dir`.

    """

    def __init__(self, data_dir=None, label_resolution=None, split='train',
                 ignore_labels=True):
        if data_dir is None:
            data_dir = download.get_dataset_directory(
                'pfnet/chainercv/cityscapes')
        if label_resolution not in ['fine', 'coarse']:
            raise ValueError('\'label_resolution\' argument should be either '
                             '\'fine\' or \'coarse\'.')

        img_dir = os.path.join(data_dir, os.path.join('leftImg8bit', split))
        resol = 'gtFine' if label_resolution == 'fine' else 'gtCoarse'
        label_dir = os.path.join(data_dir, resol)

        # Fail early with instructions: users may instantiate this class
        # expecting the data to be downloaded automatically, which is not
        # possible for Cityscapes.
        if not os.path.exists(img_dir) or not os.path.exists(label_dir):
            raise ValueError(
                'Cityscapes dataset does not exist at the expected location. '
                'Please download it from https://www.cityscapes-dataset.com/. '
                'Then place directory leftImg8bit at {} and {} at {}.'.format(
                    os.path.join(data_dir, 'leftImg8bit'), resol, label_dir))

        self.ignore_labels = ignore_labels

        self.label_paths = []
        self.img_paths = []
        label_suffix = '{}_labelIds'.format(resol)
        # glob returns entries in arbitrary order; sorting makes the index
        # of each example reproducible across runs and machines.
        for split_dir in sorted(glob.glob(os.path.join(label_dir, '*'))):
            # Substring match is kept on purpose (it was chosen with the
            # coarse annotations in mind), but it is applied to the
            # directory name only: matching against the whole path would
            # select every split whenever a parent directory happens to
            # contain the split name.
            if split not in os.path.basename(split_dir):
                continue
            for city_dir in sorted(glob.glob(os.path.join(split_dir, '*'))):
                city = os.path.basename(city_dir)
                for label_path in sorted(glob.glob(
                        os.path.join(city_dir, '*_labelIds.png'))):
                    # The image shares the file name of its label except
                    # for the trailing '<resol>_labelIds' part.
                    img_name = os.path.basename(label_path).replace(
                        label_suffix, 'leftImg8bit')
                    self.label_paths.append(label_path)
                    self.img_paths.append(
                        os.path.join(img_dir, city, img_name))

    def __len__(self):
        return len(self.img_paths)

    def get_example(self, i):
        """Returns the i-th example.

        Returns a color image and a label image. The color image is in CHW
        format and the label image is in HW format.

        Args:
            i (int): The index of the example.

        Returns:
            tuple of a color image and a label whose shapes are (3, H, W) and
            (H, W) respectively. H and W are height and width of the image.
            The dtype of the color image is :obj:`numpy.float32` and
            the dtype of the label image is :obj:`numpy.int32`.

        """
        img = read_image(self.img_paths[i])
        label_orig = read_image(
            self.label_paths[i], dtype=np.int32, color=False)[0]
        if not self.ignore_labels:
            # Raw Cityscapes label ids are returned untouched.
            return img, label_orig

        # Start with everything ignored (-1) and fill in only the classes
        # that take part in evaluation, remapped from id to trainId.
        label_out = np.full(label_orig.shape, -1, dtype=np.int32)
        for label in cityscapes_labels:
            if not label.ignoreInEval:
                label_out[label_orig == label.id] = label.trainId
        return img, label_out
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# The values used here are copied from cityscapesScripts: | ||
# https://github.com/mcordts/cityscapesScripts | ||
|
||
from collections import namedtuple | ||
|
||
|
||
# One row of the Cityscapes label table. ``id`` is the value compared
# against the pixels of the ``*_labelIds.png`` annotations, ``trainId`` is
# the value used for training, and ``ignoreInEval`` marks classes that are
# left out of evaluation. ``color`` is a 3-tuple (presumably RGB).
Label = namedtuple(
    'Label',
    'name id trainId category categoryId hasInstances ignoreInEval color')

# The full label table, in the order of the original definition.
cityscapes_labels = tuple(Label._make(row) for row in (
    # void
    ('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
    ('egovehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
    ('rectificationborder', 2, 255, 'void', 0, False, True, (0, 0, 0)),
    ('outofroi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
    ('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
    ('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
    ('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
    # flat
    ('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
    ('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
    ('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
    ('railtrack', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
    # construction
    ('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
    ('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
    ('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
    ('guardrail', 14, 255, 'construction', 2, False, True,
     (180, 165, 180)),
    ('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
    ('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
    # object
    ('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
    ('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
    ('trafficlight', 19, 6, 'object', 3, False, False, (250, 170, 30)),
    ('trafficsign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
    # nature
    ('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
    ('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
    # sky
    ('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
    # human
    ('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
    ('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
    # vehicle
    ('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
    ('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
    ('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
    ('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
    ('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
    ('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
    ('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
    ('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
    ('licenseplate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
))

# Names of the classes that are not ignored in evaluation, in table order.
cityscapes_label_names = tuple(
    label.name for label in cityscapes_labels if not label.ignoreInEval)

# Colors of the same classes, aligned with ``cityscapes_label_names``.
cityscapes_label_colors = tuple(
    label.color for label in cityscapes_labels if not label.ignoreInEval)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from chainer import testing | ||
from chainer.testing import attr | ||
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_labels | ||
from chainercv.datasets import CityscapesSemanticSegmentationDataset | ||
from chainercv.utils import assert_is_semantic_segmentation_dataset | ||
from chainercv.utils import write_image | ||
|
||
|
||
@testing.parameterize(
    {'split': 'train', 'n_class': 19, 'label_mode': 'fine',
     'ignore_labels': True},
    {'split': 'val', 'n_class': 34, 'label_mode': 'coarse',
     'ignore_labels': False}
)
class TestCityscapesSemanticSegmentationDataset(unittest.TestCase):

    """Tests the Cityscapes dataset class against a small fake dataset.

    ``setUp`` writes ten random image/label pairs for a single city into a
    temporary directory laid out like the real dataset.
    """

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        # Registered right away so that the directory is removed even when
        # the rest of setUp raises (tearDown would not run in that case).
        self.addCleanup(shutil.rmtree, self.temp_dir)

        img_dir = os.path.join(
            self.temp_dir, 'leftImg8bit/{}/aachen'.format(self.split))
        resol = 'gtFine' if self.label_mode == 'fine' else 'gtCoarse'
        label_dir = os.path.join(
            self.temp_dir, '{}/{}/aachen'.format(resol, self.split))
        os.makedirs(img_dir)
        os.makedirs(label_dir)

        for i in range(10):
            img = np.random.randint(
                0, 255, size=(3, 128, 160)).astype(np.uint8)
            write_image(img, os.path.join(
                img_dir, 'aachen_000000_0000{:02d}_leftImg8bit.png'.format(i)))

            # Raw label ids 0..33, i.e. every non-negative id of the table.
            label = np.random.randint(
                0, 34, size=(1, 128, 160)).astype(np.int32)
            write_image(label, os.path.join(
                label_dir,
                'aachen_000000_0000{:02d}_{}_labelIds.png'.format(i, resol)))

        self.dataset = CityscapesSemanticSegmentationDataset(
            self.temp_dir, self.label_mode, self.split, self.ignore_labels)

    def test_ignore_labels(self):
        # With ignore_labels the output may only hold -1 and the trainIds of
        # the evaluated classes; without it the raw ids come back unchanged.
        if self.ignore_labels:
            valid = set(
                label.trainId for label in cityscapes_labels
                if not label.ignoreInEval)
            valid.add(-1)
        else:
            valid = set(
                label.id for label in cityscapes_labels if label.id >= 0)

        for _, label_img in self.dataset:
            found = set(np.unique(label_img).tolist())
            self.assertTrue(
                found <= valid,
                'unexpected label values: {}'.format(sorted(found - valid)))

    @attr.slow
    def test_cityscapes_semantic_segmentation_dataset(self):
        assert_is_semantic_segmentation_dataset(
            self.dataset, self.n_class, n_example=10)


testing.run_module(__name__, __file__)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo