Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge branch 'Cityscapes_Loader' into super-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Sep 8, 2017
2 parents 50e582c + 85b0385 commit 38366c6
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 95 deletions.
176 changes: 82 additions & 94 deletions inferno/io/box/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import zipfile
import io
import os
import torch.utils.data as data
from PIL import Image
from os.path import join
from os.path import join, relpath, abspath
from ...utils.exceptions import assert_
from ..transform.base import Compose
from ..transform.generic import \
Normalize, NormalizeRange, Cast, AsTorchBatch, Project, Label2OneHot
from ..transform.image import \
RandomSizedCrop, RandomGammaCorrection, RandomFlip, Scale, PILImage2NumPyArray
from ..core import Concatenate


CITYSCAPES_CLASSES = {
0: 'unlabeled',
Expand Down Expand Up @@ -174,25 +177,43 @@ def get_matching_labelimage_file(f, groundtruth):
return '/'.join(fs)


def make_dataset(image_zip_file, split):
def get_filelist(path):
if path.endswith('.zip'):
return zipfile.ZipFile(path, 'r').filelist
elif os.path.isdir(path):
return [relpath(join(root, filename), abspath(join(path, '..')))
for root, _, filenames in os.walk(path) for filename in filenames]
else:
raise NotImplementedError("Path must be a zip archive or a directory.")


def make_dataset(path, split):
images = []
for f in zipfile.ZipFile(image_zip_file, 'r').filelist:
fn = f.filename.split('/')
if fn[-1].endswith('.png') and fn[1] == split:
for f in get_filelist(path):
if isinstance(f, str):
fn = f
fns = f.split('/')
else:
fn = f.filename
fns = f.filename.split('/')
if fns[-1].endswith('.png') and fns[1] == split:
# use first folder name to identify train/val/test images
if split == 'train_extra':
groundtruth = 'gtCoarse'
else:
groundtruth = 'gtFine'

fl = get_matching_labelimage_file(f.filename, groundtruth)
fl = get_matching_labelimage_file(fn, groundtruth)
images.append((f, fl))
return images


def extract_image(archive, image_path):
# read image directly from zipfile
return Image.open(io.BytesIO(zipfile.ZipFile(archive, 'r').read(image_path)))
def extract_image(path, image_path):
if path.endswith('.zip'):
# read image directly from zipfile if path is a zip
return Image.open(io.BytesIO(zipfile.ZipFile(path, 'r').read(image_path)))
else:
return Image.open(join(abspath(join(path, '..')), image_path), 'r')


class Cityscapes(data.Dataset):
Expand All @@ -211,7 +232,7 @@ class Cityscapes(data.Dataset):
MEAN = CITYSCAPES_MEAN
STD = CITYSCAPES_STD

def __init__(self, root_folder, split='train',
def __init__(self, root_folder, split='train', read_from_zip_archive=True,
image_transform=None, label_transform=None, joint_transform=None):
"""
Parameters:
Expand All @@ -224,26 +245,23 @@ def __init__(self, root_folder, split='train',
"`split` must be one of {}".format(set(self.SPLIT_NAME_MAPPING.keys())),
KeyError)
self.split = self.SPLIT_NAME_MAPPING.get(split)
self.read_from_zip_archive = read_from_zip_archive

# Data path
if self.split == 'train_extra':
self.image_zip_file = join(root_folder, 'leftImg8bit_trainextra.zip')
self.label_zip_file = join(root_folder, 'gtCoarse.zip')
else:
self.image_zip_file = join(root_folder, 'leftImg8bit_trainvaltest.zip')
self.label_zip_file = join(root_folder, 'gtFine_trainvaltest.zip')
# Get roots
self.image_root, self.label_root = [join(root_folder, groot)
for groot in self.get_image_and_label_roots()]

# Transforms
self.image_transform = image_transform
self.label_transform = label_transform
self.joint_transform = joint_transform
# Make list with paths to the images
self.image_paths = make_dataset(self.image_zip_file, self.split)
self.image_paths = make_dataset(self.image_root, self.split)

def __getitem__(self, index):
pi, pl = self.image_paths[index]
image = extract_image(self.image_zip_file, pi)
label = extract_image(self.label_zip_file, pl)
image = extract_image(self.image_root, pi)
label = extract_image(self.label_root, pl)
# Apply transforms
if self.image_transform is not None:
image = self.image_transform(image)
Expand All @@ -261,9 +279,27 @@ def download(self):
# https://www.cityscapes-dataset.com/
raise NotImplementedError


def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), labels_as_onehot=False,
train_batch_size=1, validate_batch_size=1, num_workers=2):
def get_image_and_label_roots(self):
all_roots = {
'zipped':
{
'train': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),
'val': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),
'train_extra': ('leftImg8bit_trainextra.zip', 'gtCoarse.zip')
},
'unzipped':
{
'train': ('leftImg8bit', 'gtFine'),
'val': ('leftImg8bit', 'gtFine'),
'train_extra': ('leftImg8bit', 'gtCoarse')
}
}
image_and_label_roots = all_roots\
.get('zipped' if self.read_from_zip_archive else 'unzipped').get(self.split)
return image_and_label_roots


def make_transforms(image_shape, labels_as_onehot):
# Make transforms
image_transforms = Compose(PILImage2NumPyArray(),
NormalizeRange(),
Expand All @@ -287,90 +323,42 @@ def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), labels_as_o
# Applying Label2OneHot on the full label image makes it unnecessarily expensive,
# because we're throwing it away with RandomSizedCrop and Scale. Tests show that it's
# ~1 sec faster per image.
joint_transforms\
joint_transforms \
.add(Label2OneHot(num_classes=len(CITYSCAPES_LABEL_WEIGHTS), dtype='bool',
apply_to=[1]))\
apply_to=[1])) \
.add(Cast('float', apply_to=[1]))
else:
# Cast label image to long
joint_transforms.add(Cast('long', apply_to=[1]))

# Batchify
joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))

# Return as kwargs
return {'image_transform': image_transforms,
'label_transform': label_transforms,
'joint_transform': joint_transforms}


def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), labels_as_onehot=False,
include_coarse_dataset=False, read_from_zip_archive=True,
train_batch_size=1, validate_batch_size=1, num_workers=2):
# Build datasets
train_dataset = Cityscapes(root_directory, split='train',
image_transform=image_transforms,
label_transform=label_transforms,
joint_transform=joint_transforms)
read_from_zip_archive=read_from_zip_archive,
**make_transforms(image_shape, labels_as_onehot))
if include_coarse_dataset:
# Build coarse dataset
coarse_dataset = Cityscapes(root_directory, split='train_extra',
read_from_zip_archive=read_from_zip_archive,
**make_transforms(image_shape, labels_as_onehot))
# ... and concatenate with train_dataset
train_dataset = Concatenate(coarse_dataset, train_dataset)
validate_dataset = Cityscapes(root_directory, split='validate',
image_transform=image_transforms,
label_transform=label_transforms,
joint_transform=joint_transforms)

read_from_zip_archive=read_from_zip_archive,
**make_transforms(image_shape, labels_as_onehot))

# Build loaders
train_loader = data.DataLoader(train_dataset, batch_size=train_batch_size,
shuffle=True, num_workers=num_workers, pin_memory=True)
validate_loader = data.DataLoader(validate_dataset, batch_size=validate_batch_size,
shuffle=True, num_workers=num_workers, pin_memory=True)

return train_loader, validate_loader



def get_cityscapes_train_loader(root_directory, dataset='train', image_shape=(1024, 2048), labels_as_onehot=False,
batch_size=1, num_workers=2):

DATASET_NAME_MAPPING = {'train': 'train',
'training': 'train',
'training_extra': 'train_extra',
'train_extra': 'train_extra'}

assert_(dataset in DATASET_NAME_MAPPING.keys(),
"`dataset` must be one of {}".format(set(DATASET_NAME_MAPPING.keys())), KeyError)
dataset_name = DATASET_NAME_MAPPING.get(dataset)

# Make transforms
image_transforms = Compose(PILImage2NumPyArray(),
NormalizeRange(),
RandomGammaCorrection(),
Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD))
label_transforms = Compose(PILImage2NumPyArray(),
Project(projection=CITYSCAPES_CLASSES_TO_LABELS))
joint_transforms = Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),
preserve_aspect_ratio=True),
# Scale raw image back to the original shape
Scale(output_image_shape=image_shape,
interpolation_order=3, apply_to=[0]),
# Scale segmentation back to the original shape
# (without interpolation)
Scale(output_image_shape=image_shape,
interpolation_order=0, apply_to=[1]),
RandomFlip(allow_ud_flips=False),
# Cast raw image to float
Cast('float', apply_to=[0]))
if labels_as_onehot:
# Applying Label2OneHot on the full label image makes it unnecessarily expensive,
# because we're throwing it away with RandomSizedCrop and Scale. Tests show that it's
# ~1 sec faster per image.
joint_transforms\
.add(Label2OneHot(num_classes=len(CITYSCAPES_LABEL_WEIGHTS), dtype='bool',
apply_to=[1]))\
.add(Cast('float', apply_to=[1]))
else:
# Cast label image to long
joint_transforms.add(Cast('long', apply_to=[1]))

# Batchify
joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))

# Build datasets
train_dataset = Cityscapes(root_directory, split=dataset_name,
image_transform=image_transforms,
label_transform=label_transforms,
joint_transform=joint_transforms)

loader = data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=num_workers, pin_memory=True)

return loader
52 changes: 51 additions & 1 deletion tests/io/box/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class TestCityscapes(unittest.TestCase):
CITYSCAPES_ROOT = None
PLOT_DIRECTORY = join(dirname(__file__), 'plots')
INCLUDE_COARSE = False

def get_cityscapes_root(self):
if self.CITYSCAPES_ROOT is None:
Expand All @@ -26,11 +27,24 @@ def test_cityscapes_dataset_without_transforms(self):
self.assertSequenceEqual(label.shape, (1024, 2048))
self.assertLessEqual(label.max(), 33)

def test_cityscapes_dataset_without_transforms_unzipped(self):
from inferno.io.box.cityscapes import Cityscapes
cityscapes = Cityscapes(join(self.get_cityscapes_root(), 'extracted'),
read_from_zip_archive=False)
image, label = cityscapes[0]
image = np.asarray(image)
label = np.asarray(label)
self.assertSequenceEqual(image.shape, (1024, 2048, 3))
self.assertSequenceEqual(label.shape, (1024, 2048))
self.assertLessEqual(label.max(), 33)

def test_cityscapes_dataset_with_transforms(self):
from inferno.io.box.cityscapes import get_cityscapes_loaders
from inferno.utils.io_utils import print_tensor

train_loader, validate_loader = get_cityscapes_loaders(self.get_cityscapes_root())
train_loader, validate_loader = get_cityscapes_loaders(self.get_cityscapes_root(),
include_coarse_dataset=
self.INCLUDE_COARSE)
train_dataset = train_loader.dataset
tic = time.time()
image, label = train_dataset[0]
Expand All @@ -56,6 +70,42 @@ def test_cityscapes_dataset_with_transforms(self):
directory=self.PLOT_DIRECTORY)
print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))

def test_cityscapes_dataset_with_transforms_unzipped(self):
from inferno.io.box.cityscapes import get_cityscapes_loaders
from inferno.utils.io_utils import print_tensor

train_loader, validate_loader = get_cityscapes_loaders(join(self.get_cityscapes_root(),
'extracted'),
include_coarse_dataset=
self.INCLUDE_COARSE,
read_from_zip_archive=False)
train_dataset = train_loader.dataset
tic = time.time()
image, label = train_dataset[0]
toc = time.time()
print("[+] Loaded sample in {} seconds.".format(toc - tic))
# Make sure the shapes checkout
self.assertSequenceEqual(image.size(), (3, 1024, 2048))
self.assertSequenceEqual(label.size(), (1024, 2048))
self.assertEqual(image.type(), 'torch.FloatTensor')
self.assertEqual(label.type(), 'torch.LongTensor')
# Print tensors to make sure they look legit
if not exists(self.PLOT_DIRECTORY):
os.mkdir(self.PLOT_DIRECTORY)
else:
assert isdir(self.PLOT_DIRECTORY)
print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)
for class_id in np.unique(label.numpy()):
print_tensor((label.numpy()[None, None, ...] == class_id).astype('float32'),
prefix='LAB-{}--'.format(class_id),
directory=self.PLOT_DIRECTORY)
print_tensor(label.numpy()[None, None, ...],
prefix='LAB--',
directory=self.PLOT_DIRECTORY)
print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))


if __name__ == '__main__':
TestCityscapes.CITYSCAPES_ROOT = '/home/nrahaman/BigHeronHDD2/CityScapes'
TestCityscapes.INCLUDE_COARSE = True
unittest.main()

0 comments on commit 38366c6

Please sign in to comment.