Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add Cityscapes semantic segmentation dataset #392

Merged
merged 18 commits into from
Aug 19, 2017
4 changes: 4 additions & 0 deletions chainercv/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from chainercv.datasets.camvid.camvid_dataset import camvid_label_colors # NOQA
from chainercv.datasets.camvid.camvid_dataset import camvid_label_names # NOQA
from chainercv.datasets.camvid.camvid_dataset import CamVidDataset # NOQA
from chainercv.datasets.cityscapes.cityscapes_semantic_segmentation_dataset import CityscapesSemanticSegmentationDataset # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_label_colors # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_label_names # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_labels # NOQA
from chainercv.datasets.cub.cub_keypoint_dataset import CUBKeypointDataset # NOQA
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import glob
import os

import numpy as np

from chainer import dataset
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_labels
from chainercv.utils import read_image


class CityscapesSemanticSegmentationDataset(dataset.DatasetMixin):

"""Dataset class for a semantic segmentation task on `Cityscapes dataset`_.

.. _`Cityscapes dataset`: https://www.cityscapes-dataset.com

.. note::

Please download the data by yourself because Cityscapes dataset doesn't
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this.

Please manually downalod the data because it is not allowed to re-distribute Cityscapes dataset.

allow to re-distribute their data.

Args:
img_dir (string): Path to the image dir. It should end with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dir->directory

``leftImg8bit``.
Copy link
Member

@yuyu2172 yuyu2172 Aug 16, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use obj.

label_dir (string): Path to the dir which contains labels. It should
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dir --> directory

end with either ``gtFine`` or ``gtCoarse``.
Copy link
Member

@yuyu2172 yuyu2172 Aug 16, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use obj.

split ({'train', 'val'}): Select from dataset splits used in
Cityscapes dataset.
ignore_labels (bool): If True, the labels marked ``ignoreInEval``
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use obj.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True --> :obj:`True`

defined in the original
`cityscapesScripts<https://github.com/mcordts/cityscapesScripts>_`
will be replaced with `-1` in the `get_example` method.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`-1`  -->  :obj:`-1`
`get_example` -->  :meth:`get_example`

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the following in the end.
The default value is :obj:True``.


"""

def __init__(self, img_dir, label_dir, split='train', ignore_labels=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about setting a default directory names?
if img_dir is None,
how about setting the img_dir as CHAINER_DATASET_ROOT/pfnet/chainercv/cityscapes/leftImg8bit?
Same for label_dir.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but for label_dir, we can't assume which label users use.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By having a default directory value, users do not need to specify the directory path once they set up their files properly. I think that this feature is very useful.

How about changing the options to data_dir and label_mode.
(i.e. def __init__(self, data_dir=None, label_mode=None, split='train', ignore_labels=True):)
The data_dir would be the path to the root dir whose default value is CHAINER_DATASET_ROOT/pfnet/chainercv/cityscapes.
Below the root directory, we expect at least two folders (leftImg8bit and a label directory that is going to be used).
The label_mode should raise an error when unspecified. It should be either fine or coarse.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll reflect your suggestion and add a test for it.

img_dir = os.path.join(img_dir, split)
self.ignore_labels = ignore_labels

self.label_fns, self.img_fns = [], []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using two lines and use list.
self.label_fnames = list()
self.img_fnames = list()

Copy link
Member

@Hakuyume Hakuyume Aug 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a check of []/list() to our coding style checker?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resol = os.path.basename(label_dir)
for dname in glob.glob('{}/*'.format(label_dir)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using os.path.join(label_dir, '*') instead?
This will work even if there is / at the end of img_dir.

if split in dname:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work for Coarse dataset as well?
From the Github page, it seems that there is train_extra split.
https://github.com/mcordts/cityscapesScripts

I have not yet looked at the dataset by myself, so please tell me if I am wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's why I used if split in dname at L:59

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add that to the doc and test?

for label_fn in glob.glob(
'{}/*/*_labelIds.png'.format(dname)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

self.label_fns.append(label_fn)
for label_fn in self.label_fns:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#161 (comment)

fn is not used in ChainerCV.
However, I think that filenames is too long. An alternative would be fnames.

@Hakuyume

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry for the same mistake.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK to use fnames?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer filenames to fnames. It is not so long.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filenames is OK, but label_filenames and img_filenames are bit too long.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about label_paths and img_paths?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference between *_filenames and *_paths is unclear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*_paths is shorter than *_filenames and _fnames. If the problem of *_filenames is its length, *_paths will be a good solution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.
I think changing all *_filenames to *_paths is good. Leaving both is bad.

I will update other datasets accordingly.

@mitmul
Can you use *_paths?

img_fn = label_fn.replace(resol, 'leftImg8bit')
img_fn = img_fn.replace('_labelIds', '')
self.img_fns.append(img_fn)

def __len__(self):
return len(self.img_fns)

def get_example(self, i):
"""Returns the i-th example.

Returns a color image and a label image. The color image is in CHW
format and the label image is in HW format.

Args:
i (int): The index of the example.

Returns:
tuple of a color image and a label whose shapes are (3, H, W) and
(H, W) respectively. H and W are height and width of the image.
The dtype of the color image is :obj:`numpy.float32` and
the dtype of the label image is :obj:`numpy.int32`.

"""
img = read_image(self.img_fns[i])
label_orig = read_image(
self.label_fns[i], dtype=np.int32, color=False)[0]
H, W = label_orig.shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can delete this and use np.ones(label_orgi.shape, dtype=np.int32) * -1 instead for line 102.

if self.ignore_labels:
label_out = np.ones((H, W), dtype=np.int32) * -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can optimize this part by

  1. not initializing label_out.
  2. for loop only labels in ignore lists

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When make the loop only over the ignore id list, how can I replace the label ids which are not marked as an ignoreInEval with trainId?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh. I see. My bad.
It looks OK.
Is it OK to remove line 79?
Also, np.where is not necessary.

for label in cityscapes_labels:
if label.ignoreInEval:
label_out[np.where(label_orig == label.id)] = -1
else:
label_out[np.where(label_orig == label.id)] = label.trainId
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove np.where.

else:
label_out = label
img = img.astype(np.float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

read_image guarantees that img is np.int32.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not np.float32?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry it is typo.

return img, label_out
54 changes: 54 additions & 0 deletions chainercv/datasets/cityscapes/cityscapes_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# The values used here is copied from cityscapesScripts:
# https://github.com/mcordts/cityscapesScripts

from collections import namedtuple


Label = namedtuple(
'Label', ['name', 'id', 'trainId', 'category', 'categoryId',
'hasInstances', 'ignoreInEval', 'color'])

cityscapes_labels = tuple([
Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
Label('egovehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
Label('rectificationborder', 2, 255, 'void', 0, False, True, (0, 0, 0)),
Label('outofroi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
Label('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
Label('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
Label('railtrack', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
Label('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
Label('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
Label('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
Label(
'guardrail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
Label('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
Label('trafficlight', 19, 6, 'object', 3, False, False, (250, 170, 30)),
Label('trafficsign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
Label('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
Label('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
Label('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
Label('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
Label('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
Label('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
Label('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
Label('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
Label('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
Label('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
Label('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
Label('licenseplate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
])

cityscapes_label_names = tuple(
l.name for l in cityscapes_labels if not l.ignoreInEval)

cityscapes_label_colors = tuple(
l.color for l in cityscapes_labels if not l.ignoreInEval)
59 changes: 59 additions & 0 deletions tests/datasets_tests/cityscapes_tests/test_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
from PIL import Image

from chainer import testing
from chainer.testing import attr
from chainercv.datasets import cityscapes_label_names
from chainercv.datasets import CityscapesSemanticSegmentationDataset
from chainercv.utils import assert_is_semantic_segmentation_dataset


@testing.parameterize(
{'split': 'train'},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test ignore_labels (True, False).

{'split': 'val'}
)
class TestCityscapesSemanticSegmentationDataset(unittest.TestCase):

def setUp(self):
self.temp_dir = tempfile.mkdtemp()
img_dir = os.path.join(
self.temp_dir, 'leftImg8bit/{}/aachen'.format(self.split))
label_dir = os.path.join(
self.temp_dir, 'gtFine/{}/aachen'.format(self.split))
os.makedirs(img_dir)
os.makedirs(label_dir)

for i in range(10):
img = np.random.randint(0, 255, size=(128, 160, 3))
img = Image.fromarray(img.astype(np.uint8))
img.save(os.path.join(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use write_image.
#382

img_dir, 'aachen_000000_0000{:02d}_leftImg8bit.png'.format(i)))

label = np.random.randint(0, 20, size=(128, 160)).astype(np.uint8)
label = Image.fromarray(np.zeros((128, 160), dtype=np.uint8))
label.save(os.path.join(
label_dir,
'aachen_000000_0000{:02d}_gtFine_labelIds.png'.format(i)))

img_dir = os.path.join(self.temp_dir, 'leftImg8bit')
label_dir = os.path.join(self.temp_dir, 'gtFine')
if self.split == 'test':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unnecessary.

label_dir = None
self.dataset = CityscapesSemanticSegmentationDataset(
img_dir, label_dir, self.split)

def tearDown(self):
shutil.rmtree(self.temp_dir)

@attr.slow
def test_cityscapes_semantic_segmentation_dataset(self):
assert_is_semantic_segmentation_dataset(
self.dataset, len(cityscapes_label_names), n_example=10)


testing.run_module(__name__, __file__)