Permalink
Browse files

Add predict multibands support. Tiles and DataSet cleanup.

  • Loading branch information...
ocourtin committed Feb 9, 2019
1 parent 203b292 commit e1ce1cacfb2a57cb682b823330b250282b2f6111
@@ -7,19 +7,18 @@

import os
import sys
import torch

import numpy as np
from PIL import Image

import torch
import torch.utils.data
import cv2
import numpy as np

from robosat_pink.tiles import tiles_from_slippy_map, buffer_tile_image
from robosat_pink.tiles import tiles_from_slippy_map, tile_image_buffer, tile_image


# Single Slippy Map directory structure
class SlippyMapTiles(torch.utils.data.Dataset):
"""Dataset for images stored in slippy map format.
"""
"""Dataset for images stored in slippy map format. """

def __init__(self, root, mode, transform=None):
super().__init__()
@@ -37,30 +36,20 @@ def __len__(self):
def __getitem__(self, i):
tile, path = self.tiles[i]

if self.mode == "image":
image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)

elif self.mode == "multibands":
image = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
if len(image.shape) == 3 and image.shape[2] >= 3:
# FIXME Look twice to find an in-place way to perform a multiband BGR2RGB
g = image[:, :, 0]
image[:, :, 0] = image[:, :, 2]
image[:, :, 2] = g

elif self.mode == "mask":
if self.mode == "mask":
image = np.array(Image.open(path).convert("P"))

elif self.mode == "image":
image = tile_image(path)

if self.transform is not None:
image = self.transform(image)

return image, tile


# Multiple Slippy Map directories.
class SlippyMapTilesConcatenation(torch.utils.data.Dataset):
"""Dataset to concate multiple input images stored in slippy map format.
"""
"""Dataset to concate multiple input images stored in slippy map format. """

def __init__(self, path, channels, target, joint_transform=None):
super().__init__()
@@ -71,7 +60,7 @@ def __init__(self, path, channels, target, joint_transform=None):

for channel in channels:
for band in channel["bands"]:
self.inputs[channel["sub"]] = SlippyMapTiles(os.path.join(path, channel["sub"]), mode="multibands")
self.inputs[channel["sub"]] = SlippyMapTiles(os.path.join(path, channel["sub"]), mode="image")

self.target = SlippyMapTiles(target, mode="mask")

@@ -103,25 +92,14 @@ def __getitem__(self, i):
return tensor, mask, tile


# Todo: once we have the SlippyMapDataset this dataset should wrap
# it adding buffer and unbuffer glue on top of the raw tile dataset.
class BufferedSlippyMapDirectory(torch.utils.data.Dataset):
class BufferedSlippyMapTiles(torch.utils.data.Dataset):
"""Dataset for buffered slippy map tiles with overlap.
Note: The overlap must not span multiple tiles.
Use `unbuffer` to get back the original tile.
"""

def __init__(self, root, transform=None, size=512, overlap=32):
"""
Args:
root: the slippy map directory root with a `z/x/y.png` sub-structure.
transform: the transformation to run on the buffered tile.
size: the Slippy Map tile size in pixels
overlap: the tile border to add on every side; in pixel.
Note:
The overlap must not span multiple tiles.
Use `unbuffer` to get back the original tile.
"""

super().__init__()

@@ -138,23 +116,14 @@ def __len__(self):

def __getitem__(self, i):
tile, path = self.tiles[i]
image = np.array(buffer_tile_image(tile, self.tiles, overlap=self.overlap, tile_size=self.size))
image = np.array(tile_image_buffer(tile, self.tiles, overlap=self.overlap, tile_size=self.size))

if self.transform is not None:
image = self.transform(image)

return image, torch.IntTensor([tile.x, tile.y, tile.z])

def unbuffer(self, probs):
"""Removes borders from segmentation probabilities added to the original tile image.
Args:
probs: the segmentation probability mask to remove buffered borders.
Returns:
The probability mask with the original tile's dimensions without added overlap borders.
"""

o = self.overlap
_, x, y = probs.shape

@@ -1,119 +1,59 @@
"""Slippy Map Tiles.
The Slippy Map tile spec works with a directory structure of `z/x/y.png` where
- `z` is the zoom level
- `x` is the left / right index
- `y` is the top / bottom index
See: https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames
See: https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames
"""

import csv
import io
import os
from glob import glob
import re
import glob

import cv2
from PIL import Image
import numpy as np

import csv
import mercantile


def pixel_to_location(tile, dx, dy):
"""Converts a pixel in a tile to a coordinate.
Args:
tile: the mercantile tile to calculate the location in.
dx: the relative x offset in range [0, 1].
dy: the relative y offset in range [0, 1].
def tile_pixel_to_location(tile, dx, dy):
"""Converts a pixel in a tile to lon/lat coordinates."""

Returns:
The coordinate for the pixel in the tile.
"""
assert 0 <= dx <= 1 and 0 <= dy <= 1, "x and y offsets must be in [0, 1]"

assert 0 <= dx <= 1, "x offset is in [0, 1]"
assert 0 <= dy <= 1, "y offset is in [0, 1]"

west, south, east, north = mercantile.bounds(tile)
w, s, e, n = mercantile.bounds(tile)

def lerp(a, b, c):
return a + c * (b - a)

lon = lerp(west, east, dx)
lat = lerp(south, north, dy)

return lon, lat


def fetch_image(session, url, timeout=10):
"""Fetches the image representation for a tile.
Args:
session: the HTTP session to fetch the image from.
url: the tile imagery's url to fetch the image from.
timeout: the HTTP timeout in seconds.
Returns:
The satellite imagery as bytes or None in case of error.
"""

try:
resp = session.get(url, timeout=timeout)
resp.raise_for_status()
return io.BytesIO(resp.content)
except Exception:
return None
return lerp(w, e, dx), lerp(s, n, dy) # lon, lat


def tiles_from_slippy_map(root):
"""Loads files from an on-disk slippy map directory structure.
"""Loads files from an on-disk slippy map dir."""

Args:
root: the base directory with layout `z/x/y.*`.
Yields:
The mercantile tiles and file paths from the slippy map directory.
"""
root = os.path.expanduser(root)
paths = glob.glob(os.path.join(root, "[0-9]*/[0-9]*/[0-9]*.*"))

# The Python string functions (.isdigit, .isdecimal, etc.) handle
# unicode codepoints; we only care about digits convertible to int
def isdigit(v):
try:
_ = int(v) # noqa: F841
return True
except ValueError:
return False
for path in paths:

root = os.path.expanduser(root)
for z in os.listdir(root):
if not isdigit(z):
tile = re.match(os.path.join(root, "(?P<z>[0-9]+)/(?P<x>[0-9]+)/(?P<y>[0-9]+).+"), path)
if not tile:
continue

for x in os.listdir(os.path.join(root, z)):
if not isdigit(x):
continue

for name in os.listdir(os.path.join(root, z, x)):
y = os.path.splitext(name)[0]
yield mercantile.Tile(int(tile["x"]), int(tile["y"]), int(tile["z"])), path

if not isdigit(y):
continue

tile = mercantile.Tile(x=int(x), y=int(y), z=int(z))
path = os.path.join(root, z, x, name)
yield tile, path
def tile_from_slippy_map(root, x, y, z):
"""Retrieve a single tile from a slippy map dir."""

path = glob.glob(os.path.join(os.path.expanduser(root), z, x, y + ".*"))
if not path:
return None

def tiles_from_csv(path):
"""Read tiles from a line-delimited csv file.
return mercantile.Tile(x, y, z), path[0]

Args:
file: the path to read the csv file from.

Yields:
The mercantile tiles from the csv file.
"""
def tiles_from_csv(path):
"""Retrieve tiles from a line-delimited csv file."""

path = os.path.expanduser(path)
with open(path) as fp:
@@ -126,79 +66,65 @@ def tiles_from_csv(path):
yield mercantile.Tile(*map(int, row))


def tile_image(root, x, y, z):
"""Retrieves H,W,C numpy array, from a tile store and X,Y,Z coordinates, or `None`"""
def tile_image(path):
"""Return a multiband image numpy array, from a file path."""

try:
root = os.path.expanduser(root)
path = glob(os.path.join(root, z, x, y) + "*")
assert len(path) == 1
img = np.array(Image.open(path[0]).convert("RGB"))
except:
return None
path = os.path.expanduser(path)
image = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
if len(image.shape) == 3 and image.shape[2] >= 3: # multibands BGR2RGB
b = image[:, :, 0]
image[:, :, 0] = image[:, :, 2]
image[:, :, 2] = b

return img
return image


def adjacent_tile_image(tile, dx, dy, tiles):
"""Retrieves an adjacent tile image from a tile store.
def tile_image_fetch(session, url, timeout=10):
"""Fetch a tile image using HTTP. Need requests.Session."""

Args:
tile: the original tile to get an adjacent tile image for.
dx: the offset in tile x direction.
dy: the offset in tile y direction.
tiles: the tile store to get tiles from; must support `__getitem__` with tiles.
try:
resp = session.get(url, timeout=timeout)
resp.raise_for_status()
return io.BytesIO(resp.content)

except Exception:
return None

Returns:
The adjacent tile's image or `None` if it does not exist.
"""

x, y, z = map(int, [tile.x, tile.y, tile.z])
adjacent = mercantile.Tile(x=x + dx, y=y + dy, z=z)
def tile_image_adjacent(tile, dx, dy, tiles):
"""Retrieves an adjacent tile image if exists from a tile store, or None."""

try:
path = tiles[adjacent]
path = tiles[mercantile.Tile(x=int(tile.x) + dx, y=int(tile.y) + dy, z=int(tile.z))]
except KeyError:
return None

return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)


def buffer_tile_image(tile, tiles, overlap, tile_size):
"""Buffers a tile image adding borders on all sides based on adjacent tiles.
return tile_image(path)

Args:
tile: the tile to buffer.
tiles: available tiles; must be a mapping of tiles to their filesystem paths.
overlap: the tile border to add on every side; in pixel.
tile_size: the tile size.

Returns:
The H,W,C numpy composite image containing the original tile plus tile overlap on all sides.
It's size is `tile_size` + 2 * `overlap` pixel for each side.
"""
def tile_image_buffer(tile, tiles, overlap, tile_size):
"""Buffers a tile image adding borders on all sides based on adjacent tiles."""

assert 0 <= overlap <= tile_size, "Overlap value can't be either negative or bigger than tile_size"

tiles = dict(tiles)
x, y, z = map(int, [tile.x, tile.y, tile.z])

# 3x3 matrix (upper, center, bottom) x (left, center, right)
ul = adjacent_tile_image(tile, -1, -1, tiles)
uc = adjacent_tile_image(tile, +0, -1, tiles)
ur = adjacent_tile_image(tile, +1, -1, tiles)
cl = adjacent_tile_image(tile, -1, +0, tiles)
cc = adjacent_tile_image(tile, +0, +0, tiles)
cr = adjacent_tile_image(tile, +1, +0, tiles)
bl = adjacent_tile_image(tile, -1, +1, tiles)
bc = adjacent_tile_image(tile, +0, +1, tiles)
br = adjacent_tile_image(tile, +1, +1, tiles)
ul = tile_image_adjacent(tile, -1, -1, tiles)
uc = tile_image_adjacent(tile, +0, -1, tiles)
ur = tile_image_adjacent(tile, +1, -1, tiles)
cl = tile_image_adjacent(tile, -1, +0, tiles)
cc = tile_image_adjacent(tile, +0, +0, tiles)
cr = tile_image_adjacent(tile, +1, +0, tiles)
bl = tile_image_adjacent(tile, -1, +1, tiles)
bc = tile_image_adjacent(tile, +0, +1, tiles)
br = tile_image_adjacent(tile, +1, +1, tiles)

ts = tile_size
o = overlap
oo = overlap * 2

# Todo: instead of nodata we should probably mirror the center image
img = np.zeros((ts + oo, ts + oo, 3)).astype(np.uint8)

# fmt:off
Oops, something went wrong.

0 comments on commit e1ce1ca

Please sign in to comment.