Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Write a test
Browse files Browse the repository at this point in the history
  • Loading branch information
mitmul committed Aug 15, 2017
1 parent 5488d03 commit 5173ad1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import os

import numpy as np

from chainer import dataset
from chainercv.datasets import cityscapes_labels
from chainercv.utils import read_image

from datasets import cityscapes_labels


class CityscapesSemanticSegmentationDataset(dataset.DatasetMixin):

Expand All @@ -22,7 +22,7 @@ class CityscapesSemanticSegmentationDataset(dataset.DatasetMixin):
``leftImg8bit``.
label_dir (string): Path to the dir which contains labels. It should
end with either ``gtFine`` or ``gtCoarse``.
split ({'train', 'val', 'test'}): Select from dataset splits used in
split ({'train', 'val'}): Select from dataset splits used in
Cityscapes dataset.
ignore_labels (bool): If True, the labels marked ``ignoreInEval``
defined in the original
Expand All @@ -36,23 +36,16 @@ def __init__(self, img_dir, label_dir, split='train', ignore_labels=True):
self.ignore_labels = ignore_labels

self.label_fns, self.img_fns = [], []
if label_dir is not None:
resol = os.path.basename(label_dir)
for dname in glob.glob('{}/*'.format(label_dir)):
if split in dname:
for label_fn in glob.glob(
'{}/*/*_labelIds.png'.format(dname)):
self.label_fns.append(label_fn)
for label_fn in self.label_fns:
img_fn = label_fn.replace(resol, 'leftImg8bit')
img_fn = img_fn.replace('_labelIds', '')
self.img_fns.append(img_fn)
else:
for dname in glob.glob('{}/*'.format(img_dir)):
if split in dname:
for img_fn in glob.glob(
'{}/*_leftImg8bit.png'.format(dname)):
self.img_fns.append(img_fn)
resol = os.path.basename(label_dir)
for dname in glob.glob('{}/*'.format(label_dir)):
if split in dname:
for label_fn in glob.glob(
'{}/*/*_labelIds.png'.format(dname)):
self.label_fns.append(label_fn)
for label_fn in self.label_fns:
img_fn = label_fn.replace(resol, 'leftImg8bit')
img_fn = img_fn.replace('_labelIds', '')
self.img_fns.append(img_fn)

def __len__(self):
return len(self.img_fns)
Expand All @@ -74,8 +67,6 @@ def get_example(self, i):
"""
img = read_image(self.img_fns[i])
if self.label_fns == []:
return img
label_orig = read_image(
self.label_fns[i], dtype=np.int32, color=False)[0]
H, W = label_orig.shape
Expand Down
46 changes: 39 additions & 7 deletions tests/datasets_tests/cityscapes_tests/test_cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,59 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
from PIL import Image

from chainer import testing
from chainer.testing import attr

from chainercv.datasets import cityscapes_label_names
from chainercv.datasets import CityscapesSemanticSegmentationDataset
from chainercv.utils import assert_is_semantic_segmentation_dataset


@testing.parameterize(
{'split': 'train'},
{'split': 'val'},
{'split': 'test'}
{'split': 'val'}
)
class TestCamVidDataset(unittest.TestCase):
class TestCityscapesSemanticSegmentationDataset(unittest.TestCase):

def setUp(self):
self.dataset = CamVidDataset(split=self.split)
self.temp_dir = tempfile.mkdtemp()
img_dir = os.path.join(
self.temp_dir, 'leftImg8bit/{}/aachen'.format(self.split))
label_dir = os.path.join(
self.temp_dir, 'gtFine/{}/aachen'.format(self.split))
os.makedirs(img_dir)
os.makedirs(label_dir)

for i in range(10):
img = np.random.randint(0, 255, size=(128, 160, 3))
img = Image.fromarray(img.astype(np.uint8))
img.save(os.path.join(
img_dir, 'aachen_000000_0000{:02d}_leftImg8bit.png'.format(i)))

label = np.random.randint(0, 20, size=(128, 160)).astype(np.uint8)
label = Image.fromarray(np.zeros((128, 160), dtype=np.uint8))
label.save(os.path.join(
label_dir,
'aachen_000000_0000{:02d}_gtFine_labelIds.png'.format(i)))

img_dir = os.path.join(self.temp_dir, 'leftImg8bit')
label_dir = os.path.join(self.temp_dir, 'gtFine')
if self.split == 'test':
label_dir = None
self.dataset = CityscapesSemanticSegmentationDataset(
img_dir, label_dir, self.split)

def tearDown(self):
shutil.rmtree(self.temp_dir)

@attr.slow
def test_camvid_dataset(self):
def test_cityscapes_semantic_segmentation_dataset(self):
assert_is_semantic_segmentation_dataset(
self.dataset, len(camvid_label_names), n_example=10)
self.dataset, len(cityscapes_label_names), n_example=10)


testing.run_module(__name__, __file__)

0 comments on commit 5173ad1

Please sign in to comment.