In [2]:
import os
import concurrent.futures

import geopandas as gpd
import pandas as pd
import numpy as np
import rasterio
from rasterio import windows
from rasterio.features import geometry_window
from tqdm import tqdm
import utm


# GDAL configuration, passed through the process environment:
# GDAL_PAM_ENABLED=NO turns off PAM (the .aux.xml sidecar files),
# GDAL_NUM_THREADS sets the number of threads GDAL may use.
os.environ.update(
    {
        "GDAL_PAM_ENABLED": "NO",
        "GDAL_NUM_THREADS": "50",
    }
)

In [3]:
# --- Tile geometry (pixels) -------------------------------------------------
tile_width = tile_height = 512
tile_overlap = 256

# --- Target cell widths -----------------------------------------------------
# Pixel sizes the rasters are resampled to, expressed in the units of the
# (UTM) raster CRS: 0.02, 0.04, ... up to 0.20.
cell_widths = np.arange(start=0.02, stop=0.22, step=0.02)

# --- Input / output locations -----------------------------------------------
# Filename template with three positional fields.
out_filename = "{}_{}_{}.tif"

images_dir = "/net/scratch/jmoehring/images_select/"
masks_dir = "/net/scratch/jmoehring/masks_select/"
labels_dir = "/net/data_ssd/tree_mortality_orthophotos/labels_and_aois/"
tiles_out_dir = "/net/scratch/jmoehring/tiles_select/"

# --- Parallelism ------------------------------------------------------------
jobs = 3

In [3]:
def shorten_list(arr, target):
    """Drop the leading elements of ``arr`` that do not exceed ``target``.

    The sequence is scanned front to back; the slice starting at the first
    element strictly greater than ``target`` is returned (everything after
    that element is kept, whatever its value).

    Args:
        arr: Sliceable sequence of comparable values (list, NumPy array, ...).
        target: Threshold the elements are compared against.

    Returns:
        ``arr[i:]`` for the first index ``i`` with ``arr[i] > target``, or an
        empty list when no element is greater than ``target``.
    """
    for position, value in enumerate(arr):
        if value > target:
            return arr[position:]
    return []

In [4]:
def _tile_offsets(start, stop, size, step):
    """Return the distinct tile start offsets covering ``[start, stop)`` along one axis."""
    offsets = []
    for offset in range(start, stop, step):
        if offset + size > stop:
            # Pull the last tile back so that it ends exactly at ``stop``.
            offset = stop - size
        # Clamped offsets can only repeat the most recent value, so comparing
        # against the last entry is enough to remove duplicates.
        if not offsets or offsets[-1] != offset:
            offsets.append(offset)
    return offsets


def get_windows(xmin, ymin, xmax, ymax, tile_width, tile_height, overlap):
    """Yield overlapping, fixed-size tile windows covering a pixel extent.

    Tiles are laid out on a grid with a stride of ``tile size - overlap``.
    A tile that would extend past ``xmax``/``ymax`` is shifted back so that it
    ends exactly on that edge. Each distinct window is yielded once, even when
    several grid positions are shifted onto the same offset.

    Args:
        xmin, ymin: First column / row of the extent (inclusive).
        xmax, ymax: End column / row of the extent (exclusive).
        tile_width, tile_height: Tile size in pixels.
        overlap: Number of pixels shared by neighbouring tiles on both axes.

    Yields:
        ``windows.Window(col_off, row_off, tile_width, tile_height)``, iterating
        columns in the outer loop and rows in the inner loop.

    Raises:
        ValueError: If ``overlap`` is not smaller than both tile dimensions
            (raised when iteration starts, since this is a generator).

    Note:
        When the extent is smaller than a tile, the single resulting window
        ends at ``xmax``/``ymax`` and therefore starts before ``xmin``/``ymin``.
    """
    xstep = tile_width - overlap
    ystep = tile_height - overlap
    if xstep <= 0 or ystep <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than the tile size "
            f"({tile_width}x{tile_height})"
        )

    # Row offsets do not depend on the column, so compute them once.
    row_offsets = _tile_offsets(ymin, ymax, tile_height, ystep)
    for x in _tile_offsets(xmin, xmax, tile_width, xstep):
        for y in row_offsets:
            yield windows.Window(x, y, tile_width, tile_height)

In [5]:
def get_utm_crs(dr):
    """Return the UTM coordinate reference system to use for a dataset.

    Args:
        dr: Dataset-like object exposing ``crs`` (with a ``to_epsg()`` method)
            and an affine ``transform`` indexable as ``(a, b, c, d, e, f)``.

    Returns:
        ``dr.crs`` unchanged when its EPSG code is already a WGS84 UTM zone
        (32601-32660 north, 32701-32760 south). For EPSG:4326 input, a string
        ``"EPSG:<code>"`` naming the UTM zone that contains the raster's origin
        (its ``transform`` offsets ``c``/``f``).

    Raises:
        ValueError: If the dataset CRS is neither WGS84 UTM nor EPSG:4326.
    """
    epsg_code = dr.crs.to_epsg()
    if epsg_code in range(32601, 32661) or epsg_code in range(32701, 32761):
        return dr.crs
    if epsg_code != 4326:
        raise ValueError(f"Unknown CRS: {epsg_code}")

    # In a geographic CRS the affine offsets are the origin's coordinates:
    # index 2 (c) is the x origin = longitude, index 5 (f) is the y origin =
    # latitude. Index 0 (a) is the pixel width and must not be used here.
    longitude = dr.transform[2]
    latitude = dr.transform[5]

    # utm.from_latlon expects (latitude, longitude) and returns
    # (easting, northing, zone_number, zone_letter).
    zone_number = utm.from_latlon(latitude, longitude)[2]

    utm_code = 32600 + zone_number
    # zone_letter is a latitude band (C..X), not a hemisphere flag: band "S"
    # lies in the northern hemisphere. Decide the hemisphere from the sign of
    # the latitude instead.
    if latitude < 0:
        utm_code += 100
    return f"EPSG:{utm_code}"

In [6]:
def reproject_dataset_to_utm(dataset, resampling_method):
    """Warp every band of ``dataset`` into its UTM CRS, into an in-memory file.

    The target CRS comes from ``get_utm_crs``; the output grid (transform,
    width, height) is the default one computed by rasterio for the dataset's
    bounds. The output is written DEFLATE-compressed.

    Args:
        dataset: Open rasterio dataset to reproject.
        resampling_method: ``rasterio.enums.Resampling`` member used for the warp.

    Returns:
        A ``rasterio.io.MemoryFile`` holding the reprojected raster. The caller
        owns it and must close it when done.

    Raises:
        ValueError: Propagated from ``get_utm_crs`` for unsupported CRSs.
    """
    utm_crs = get_utm_crs(dataset)
    default_transform, width, height = rasterio.warp.calculate_default_transform(
        dataset.crs, utm_crs, dataset.width, dataset.height, *dataset.bounds
    )
    kwargs = dataset.meta.copy()
    kwargs.update(
        {
            "crs": utm_crs,
            "transform": default_transform,
            "width": width,
            "height": height,
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        with memfile.open(**kwargs, compress="DEFLATE") as dst:
            for i in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, i),
                    destination=rasterio.band(dst, i),
                    src_transform=dataset.transform,
                    src_crs=dataset.crs,
                    dst_transform=default_transform,
                    dst_crs=utm_crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Nobody else holds a reference yet, so release the in-memory buffer
        # here (also on KeyboardInterrupt) before propagating the error.
        memfile.close()
        raise
    return memfile

In [1]:
def rescale_dataset_to_cell_width(dataset, cell_width, resampling_method):
    """Resample ``dataset`` to square pixels of size ``cell_width``, in memory.

    The CRS and the upper-left origin are kept; only the pixel size changes.
    Width and height are rounded up so the new grid covers the full original
    extent. The output is a DEFLATE-compressed, tiled raster whose internal
    block size is taken from the module-level ``tile_overlap``.

    Args:
        dataset: Open rasterio dataset to resample.
        cell_width: Target pixel size, in the units of the dataset CRS.
        resampling_method: ``rasterio.enums.Resampling`` member used for the warp.

    Returns:
        A ``rasterio.io.MemoryFile`` holding the resampled raster. The caller
        owns it and must close it when done.
    """
    kwargs = dataset.meta.copy()
    kwargs.update(
        {
            "width": int(
                np.ceil(dataset.width * abs(dataset.transform.a) / cell_width)
            ),
            "height": int(
                np.ceil(dataset.height * abs(dataset.transform.e) / cell_width)
            ),
            # Same origin (c, f), new pixel size; negative y step = north-up.
            "transform": rasterio.Affine(
                cell_width,
                0.0,
                dataset.transform.c,
                0.0,
                -cell_width,
                dataset.transform.f,
            ),
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        with memfile.open(
            **kwargs,
            compress="DEFLATE",
            tiled=True,
            blockxsize=tile_overlap,
            blockysize=tile_overlap
        ) as dst:
            for i in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, i),
                    destination=rasterio.band(dst, i),
                    src_transform=dataset.transform,
                    src_crs=dataset.crs,
                    dst_transform=kwargs["transform"],
                    dst_crs=dataset.crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Nobody else holds a reference yet, so release the in-memory buffer
        # here (also on KeyboardInterrupt) before propagating the error.
        memfile.close()
        raise
    return memfile

SyntaxError: invalid syntax (1177094634.py, line 37)

In [8]:
# Empty table with one column per recorded attribute
# (presumably one row per written tile - TODO confirm once it is filled).
register_columns = [
    "image_path",
    "mask_path",
    "resolution",
    "x",
    "y",
    "label_quality",
]
register_df = pd.DataFrame(columns=register_columns)

In [9]:
def process_file(image_filename):
    """Reproject one image/mask pair to UTM and resample it to each target cell width.

    The mask and label files are located by naming convention
    (``<name>_mask.tif`` in ``masks_dir``, ``<name>_polygons.gpkg`` in
    ``labels_dir``). Image data is resampled bilinearly, mask data with
    nearest-neighbour so that class values are not blended. Only cell widths
    coarser than the reprojected image's own pixel size are processed.

    All datasets and in-memory files are released when the function returns,
    including when an error occurs.

    Args:
        image_filename: File name (not path) of a ``.tif`` inside ``images_dir``.
    """
    image_filepath = os.path.join(images_dir, image_filename)

    mask_filename = image_filename.replace(".tif", "_mask.tif")
    mask_filepath = os.path.join(masks_dir, mask_filename)

    label_filename = image_filename.replace(".tif", "_polygons.gpkg")
    label_filepath = os.path.join(labels_dir, label_filename)

    bilinear = rasterio.enums.Resampling.bilinear
    nearest = rasterio.enums.Resampling.nearest

    # Context managers are entered left to right and released in reverse, so
    # each dataset is closed before the MemoryFile that backs it.
    with rasterio.open(image_filepath) as idr, \
            rasterio.open(mask_filepath) as mdr, \
            reproject_dataset_to_utm(idr, bilinear) as image_memfile, \
            reproject_dataset_to_utm(mdr, nearest) as mask_memfile, \
            rasterio.open(image_memfile, "r+") as image_repojected, \
            rasterio.open(mask_memfile, "r+") as mask_repojected:

        # Loaded and reprojected but not consumed yet by the code below.
        gdf_label = gpd.read_file(label_filepath, layer="aoi")
        gdf_label = gdf_label.to_crs(image_repojected.crs)

        out_dir = os.path.join(tiles_out_dir, image_filename.replace(".tif", ""))
        os.makedirs(out_dir, exist_ok=True)
        os.makedirs(os.path.join(out_dir, "orig"), exist_ok=True)

        # Compare against the pixel size of the *reprojected* raster: the
        # source may be in EPSG:4326, where the pixel size is in degrees and
        # cannot be compared with the target cell widths.
        native_cell_width = abs(image_repojected.transform.a)

        for cell_width in tqdm(shorten_list(cell_widths, native_cell_width)):
            with rescale_dataset_to_cell_width(
                image_repojected, cell_width, bilinear
            ) as image_rescaled_memfile, rescale_dataset_to_cell_width(
                mask_repojected, cell_width, nearest
            ) as mask_rescaled_memfile, rasterio.open(
                image_rescaled_memfile, "r+"
            ) as image_rescaled, rasterio.open(
                mask_rescaled_memfile, "r+"
            ) as mask_rescaled:
                # TODO: nothing is written from the rescaled datasets yet;
                # they are only created, opened and released again.
                pass

In [10]:
# Run the pipeline sequentially over every entry of the image directory,
# showing one progress-bar step per file.
image_filenames = os.listdir(images_dir)
for filename in tqdm(image_filenames):
    process_file(filename)

100%|██████████| 10/10 [01:39<00:00, 10.00s/it]
100%|██████████| 1/1 [02:17<00:00, 137.51s/it]


In [11]:
# def process_file(image_filepath, mask_dir, label_dir, image_out_dir, mask_out_dir):
#     image_filename = os.path.basename(image_filepath)

#     mask_filename = image_filename.replace(".tif", "_mask.tif")
#     mask_filepath = os.path.join(mask_dir, mask_filename)

#     label_filename = image_filename.replace(".tif", "_polygons.gpkg")
#     label_filepath = os.path.join(label_dir, label_filename)

#     with rasterio.open(image_filepath) as image_src:
#         with rasterio.open(mask_filepath) as mask_src:
#             image_metadata = image_src.meta.copy()
#             mask_metadata = mask_src.meta.copy()

#             gdf_label = gpd.read_file(label_filepath, layer="aoi")
#             gdf_label = gdf_label.to_crs(image_src.crs)

#             for _, row in gdf_label.iterrows():
#                 aoi_window = geometry_window(image_src, [row.geometry])
#                 xmin, ymin = aoi_window.col_off, aoi_window.row_off
#                 xmax, ymax = xmin + aoi_window.width, ymin + aoi_window.height

#                 for window in get_windows(
#                     xmin, ymin, xmax, ymax, tile_width, tile_height, tile_overlap
#                 ):
#                     transform = windows.transform(window, image_src.transform)
#                     image_metadata["transform"] = transform
#                     image_metadata["width"], image_metadata["height"] = (
#                         window.width,
#                         window.height,
#                     )
#                     mask_metadata["transform"] = transform
#                     mask_metadata["width"], mask_metadata["height"] = (
#                         window.width,
#                         window.height,
#                     )
#                     image_out_filepath = os.path.join(
#                         image_out_dir,
#                         out_filename.format(
#                             image_filename.replace(".tif", ""),
#                             window.col_off,
#                             window.row_off,
#                         ),
#                     )
#                     mask_out_filepath = os.path.join(
#                         mask_out_dir,
#                         out_filename.format(
#                             mask_filename.replace(".tif", ""),
#                             window.col_off,
#                             window.row_off,
#                         ),
#                     )

#                     if not os.path.exists(image_out_filepath) and not os.path.exists(
#                         mask_out_filepath
#                     ):
#                         out_image = image_src.read(window=window)

#                         filled_fraction = 0
#                         if image_metadata["nodata"] is None:
#                             filled_fraction = (
#                                 np.count_nonzero(out_image) / out_image.size
#                             )
#                         else:
#                             filled_fraction = (
#                                 np.count_nonzero(out_image != image_metadata["nodata"])
#                                 / out_image.size
#                             )

#                         if filled_fraction > 0.99:
#                             out_mask = mask_src.read(window=window)
#                             with rasterio.open(
#                                 image_out_filepath, "w", **image_metadata
#                             ) as dst:
#                                 dst.write(out_image)
#                             with rasterio.open(
#                                 mask_out_filepath, "w", **mask_metadata
#                             ) as dst:
#                                 dst.write(out_mask)

In [12]:
# with tqdm(total=len(os.listdir(images_dir))) as pbar:
#     with concurrent.futures.ProcessPoolExecutor(max_workers=jobs) as executor:
#         futures = [
#             executor.submit(
#                 process_file,
#                 os.path.join(images_dir, filename),
#                 masks_dir,
#                 labels_dir,
#                 images_out_dir,
#                 masks_out_dir,
#             )
#             for filename in os.listdir(images_dir)
#         ]
#         # Wait for all futures to complete
#         for _ in concurrent.futures.as_completed(futures):
#             pbar.update(1)

In [13]:
# output.show()