Skip to content

Commit

Permalink
Read TIFF chips using rasterio instead of PIL to allow reading non-RG…
Browse files Browse the repository at this point in the history
…B TIFF chips (#1932)

* read TIFF chips using rasterio instead of PIL

To allow reading non-rgb TIFFs.

* add unit test

---------

Co-authored-by: Adeel Hassan <ahassan@element84.com>
  • Loading branch information
AdeelH and AdeelH committed Sep 29, 2023
1 parent 6eb5a1e commit ff5e031
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from torchvision.datasets.folder import (IMG_EXTENSIONS, DatasetFolder)
from PIL import Image
import rasterio as rio

IMG_EXTENSIONS = tuple([*IMG_EXTENSIONS, '.npy'])

Expand Down Expand Up @@ -37,6 +38,10 @@ def load_image(path: PathLike) -> np.ndarray:
ext = splitext(path)[-1]
if ext == '.npy':
img = np.load(path)
elif ext == '.tif' or ext == '.tiff':
with rio.open(path, 'r') as f:
img = f.read()
img = img.transpose(1, 2, 0)
else:
img = np.array(Image.open(path))

Expand Down
10 changes: 10 additions & 0 deletions tests/pytorch_learner/dataset/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from tempfile import TemporaryDirectory

import numpy as np
import rasterio as rio
from torchvision.datasets.folder import DatasetFolder

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.data.utils import write_window
from rastervision.pytorch_learner.dataset import (discover_images, load_image,
make_image_folder_dataset)
from rastervision.pytorch_backend.pytorch_learner_backend import write_chip
Expand Down Expand Up @@ -53,6 +55,14 @@ def test_load_image(self):
write_chip(chip, path)
np.testing.assert_array_equal(load_image(path), chip)

chip = np.random.randint(
0, 256, size=(100, 100, 8), dtype=np.uint8)
path = join(tmp_dir, '4.tif')
profile = dict(height=100, width=100, count=8, dtype=np.uint8)
with rio.open(path, 'w', **profile) as ds:
write_window(ds, chip)
np.testing.assert_array_equal(load_image(path), chip)

def test_make_image_folder_dataset(self):
with get_tmp_dir() as tmp_dir:
with TemporaryDirectory(dir=tmp_dir) as dir_a, TemporaryDirectory(
Expand Down

0 comments on commit ff5e031

Please sign in to comment.