In [None]:
import concurrent.futures
import os
import sys

import geopandas as gpd
import numpy as np
import rasterio
from rasterio import windows
from rasterio.features import geometry_window
from tqdm import tqdm

# Append the parent of the current working directory to sys.path.
# NOTE(review): presumably this is what makes the project-local `unet`
# package imported below resolvable when the notebook runs from a
# sub-directory of the repository — confirm.
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.logger import OutputLogger

# Set before any raster is opened or written by the cells below.
# NOTE(review): presumably this keeps GDAL from writing `.aux.xml` sidecar
# files next to the generated tiles — confirm against the GDAL docs.
os.environ["GDAL_PAM_ENABLED"] = "NO"

In [None]:
# Tile size in pixels and the overlap between neighbouring tiles. The step
# between consecutive tile origins is the tile size minus `tile_overlap`.
tile_width = 512
tile_height = 512
tile_overlap = 256

# Output tile name pattern, filled with
# (source file name without ".tif", window column offset, window row offset).
out_filename = "{}_{}_{}.tif"
# Input locations. For an image `<name>.tif`, `process_file` looks up
# `<name>_mask.tif` in the mask directory and `<name>_polygons.gpkg` in the
# label directory.
images_dir = "/net/data_ssd/tree_mortality_orthophotos/orthophotos/"
masks_dir = "/net/scratch/jmoehring/masks/"
# NOTE: "lables" is a misspelling of "labels"; the name is kept as-is because
# the cell that submits the jobs refers to it.
lables_dir = "/net/data_ssd/tree_mortality_orthophotos/labels_and_aois/"

# Output locations for the image tiles and the matching mask tiles.
images_out_dir = "/net/scratch/jmoehring/tiles/images/"
masks_out_dir = "/net/scratch/jmoehring/tiles/masks/"

# Number of worker processes passed to ProcessPoolExecutor(max_workers=...).
jobs = 50

In [None]:
def _tile_offsets(start, stop, size, step):
    """Yield each distinct tile origin along one axis of the span [start, stop)."""
    for offset in range(start, stop, step):
        if offset + size >= stop:
            # This tile reaches (or would overshoot) the end of the span: pull
            # it back so it ends exactly at ``stop`` and finish. Every later
            # origin would be clamped to this very same position.
            yield stop - size
            return
        yield offset


def get_windows(xmin, ymin, xmax, ymax, tile_width, tile_height, overlap):
    """Yield overlapping, fixed-size raster windows covering a pixel extent.

    Tile origins advance by ``tile_width - overlap`` columns and
    ``tile_height - overlap`` rows. The last tile on each axis is shifted
    back so that it ends exactly at ``xmax`` / ``ymax``. Each window is
    yielded exactly once, column by column.

    If the extent is narrower (or shorter) than one tile, the single tile on
    that axis starts at ``xmax - tile_width`` (``ymax - tile_height``), i.e.
    before ``xmin`` (``ymin``).

    Args:
        xmin, ymin: Integer column/row offset where the extent starts.
        xmax, ymax: Integer column/row offset where the extent ends (exclusive).
        tile_width, tile_height: Size of every window, in pixels.
        overlap: Number of pixels shared by neighbouring tiles on both axes.

    Yields:
        ``windows.Window(col_off, row_off, tile_width, tile_height)`` objects.

    Raises:
        ValueError: If ``overlap`` is not smaller than both tile dimensions,
            because the tile origin could then never advance.
    """
    xstep = tile_width - overlap
    ystep = tile_height - overlap
    if xstep <= 0 or ystep <= 0:
        raise ValueError(
            "overlap must be smaller than both tile_width and tile_height"
        )

    for x in _tile_offsets(xmin, xmax, tile_width, xstep):
        for y in _tile_offsets(ymin, ymax, tile_height, ystep):
            yield windows.Window(x, y, tile_width, tile_height)

In [None]:
def process_file(image_filepath, mask_dir, label_dir, image_out_dir, mask_out_dir):
    """Cut one image raster and its mask raster into overlapping GeoTIFF tiles.

    For every row in the "aoi" layer of the image's label GeoPackage, the
    pixel window covering that geometry is split with ``get_windows`` using
    the module-level ``tile_width``, ``tile_height`` and ``tile_overlap``.
    A tile pair (image + mask) is written only when more than 99 % of the
    image values in the tile are filled, where "filled" means non-zero if the
    image declares no nodata value and "different from nodata" otherwise.
    Tile file names follow the module-level ``out_filename`` pattern.

    A tile is skipped only when BOTH of its output files already exist. If
    just one of them is present (e.g. after an interrupted run) the pair is
    regenerated so image and mask tiles stay in sync.

    Args:
        image_filepath: Path of the image raster; ``<name>.tif``.
        mask_dir: Directory containing ``<name>_mask.tif``.
        label_dir: Directory containing ``<name>_polygons.gpkg``.
        image_out_dir: Directory that receives the image tiles.
        mask_out_dir: Directory that receives the mask tiles.

    Returns:
        None. The result is the set of files written to the output directories.
    """
    image_filename = os.path.basename(image_filepath)
    mask_filename = image_filename.replace(".tif", "_mask.tif")
    label_filename = image_filename.replace(".tif", "_polygons.gpkg")

    mask_filepath = os.path.join(mask_dir, mask_filename)
    label_filepath = os.path.join(label_dir, label_filename)

    # The name prefixes are identical for every tile, so build them once.
    image_stem = image_filename.replace(".tif", "")
    mask_stem = mask_filename.replace(".tif", "")

    with rasterio.open(image_filepath) as image_src:
        with rasterio.open(mask_filepath) as mask_src:
            image_metadata = image_src.meta.copy()
            mask_metadata = mask_src.meta.copy()
            nodata = image_metadata["nodata"]

            # Bring the AOI geometries into the image's CRS before converting
            # them to pixel windows.
            gdf_label = gpd.read_file(label_filepath, layer="aoi")
            gdf_label = gdf_label.to_crs(image_src.crs)

            for _, row in gdf_label.iterrows():
                aoi_window = geometry_window(image_src, [row.geometry])
                xmin, ymin = aoi_window.col_off, aoi_window.row_off
                xmax, ymax = xmin + aoi_window.width, ymin + aoi_window.height

                for window in get_windows(
                    xmin, ymin, xmax, ymax, tile_width, tile_height, tile_overlap
                ):
                    image_out_filepath = os.path.join(
                        image_out_dir,
                        out_filename.format(
                            image_stem, window.col_off, window.row_off
                        ),
                    )
                    mask_out_filepath = os.path.join(
                        mask_out_dir,
                        out_filename.format(
                            mask_stem, window.col_off, window.row_off
                        ),
                    )

                    # Skip only complete pairs; a lone image or mask tile is
                    # treated as unfinished work and written again.
                    if os.path.exists(image_out_filepath) and os.path.exists(
                        mask_out_filepath
                    ):
                        continue

                    out_image = image_src.read(window=window)

                    if nodata is None:
                        filled_count = np.count_nonzero(out_image)
                    else:
                        filled_count = np.count_nonzero(out_image != nodata)

                    # Drop tiles that are not (almost) completely filled.
                    if filled_count / out_image.size <= 0.99:
                        continue

                    # The same pixel window is used for the mask.
                    # NOTE(review): assumes the mask raster shares the image's
                    # pixel grid — confirm.
                    out_mask = mask_src.read(window=window)

                    transform = windows.transform(window, image_src.transform)
                    image_metadata.update(
                        transform=transform,
                        width=window.width,
                        height=window.height,
                    )
                    mask_metadata.update(
                        transform=transform,
                        width=window.width,
                        height=window.height,
                    )

                    with rasterio.open(
                        image_out_filepath, "w", **image_metadata
                    ) as dst:
                        dst.write(out_image)
                    with rasterio.open(
                        mask_out_filepath, "w", **mask_metadata
                    ) as dst:
                        dst.write(out_mask)

In [5]:
%%capture output
with tqdm(total=len(os.listdir(images_dir))) as pbar:
    with concurrent.futures.ProcessPoolExecutor(max_workers=jobs) as executor:
        futures = [
            executor.submit(
                process_file,
                os.path.join(images_dir, filename),
                masks_dir,
                lables_dir,
                images_out_dir,
                masks_out_dir,
            )
            for filename in os.listdir(images_dir)
        ]
        # Wait for all futures to complete
        for _ in concurrent.futures.as_completed(futures):
            pbar.update(1)