In [1]:
import concurrent.futures
import contextlib
import os
import sys

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import utm
from rasterio import windows
from rasterio.features import geometry_window
from tqdm import tqdm

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from utils.parallel import paral
from unet.dataset import get_windows


# GDAL configuration options, passed through environment variables.
# GDAL_PAM_ENABLED=NO: presumably meant to stop GDAL from writing ".aux.xml"
# sidecar files next to the rasters (which would otherwise also show up in
# directory listings) -- confirm against the GDAL configuration docs.
os.environ["GDAL_PAM_ENABLED"] = "NO"
# Thread count GDAL may use internally.
# NOTE(review): `process_file` is run in `cores` parallel jobs further down; if
# each job honours this value the machine may be heavily oversubscribed -- confirm.
os.environ["GDAL_NUM_THREADS"] = "50"

In [2]:
# --- Tiling geometry (pixels) -------------------------------------------------
# Size of each exported tile; passed to `get_windows` in `process_file`.
tile_width = 1024
tile_height = 1024
# Overlap handed to `get_windows`; also reused as the internal GeoTIFF block
# size of the rescaled in-memory rasters (see `rescale_dataset_to_cell_width`).
tile_overlap = 512

# Target cell widths (pixel sizes) every orthophoto is resampled to, expressed
# in the units of the reprojected raster CRS (metres for UTM). Nominally
# 0.02 ... 0.20 in steps of 0.02. Only widths coarser than an image's native
# resolution are used for that image.
cell_widths = np.arange(0.02, 0.22, 0.02)

# --- Input locations ----------------------------------------------------------
images_dir = "/net/data_ssd/tree_mortality_orthophotos/orthophotos/"
# One "<image>_mask.tif" per orthophoto "<image>.tif".
masks_dir = "/net/scratch/jmoehring/masks/"
# One "<image>_polygons.gpkg" per orthophoto; its "aoi" layer is read.
labels_dir = "/net/data_ssd/tree_mortality_orthophotos/labels_and_aois/"
# CSV with at least the columns "filename" and "label_quality".
metadata_path = "/net/scratch/jmoehring/metadata_manual_with_resolution.csv"

# --- Output location ----------------------------------------------------------
# Tiles are written to <tiles_out_dir>/<image>/<cell_width>/<col>_<row>[_mask].tif
tiles_out_dir = "/net/scratch/jmoehring/tiles_1024/"

# Number of parallel jobs passed to `paral` as `num_cores`.
cores = 30

In [3]:
def shorten_list(arr, target):
    """Return the tail of ``arr`` beginning at its first element above ``target``.

    The sequence is scanned front to back; as soon as an element strictly
    greater than ``target`` is found, the slice from that position to the end
    is returned (so later, smaller elements are kept as well). If no element
    exceeds ``target`` an empty list is returned.

    Args:
        arr: Any sliceable sequence (list, numpy array, ...).
        target: Threshold the elements are compared against.

    Returns:
        ``arr[position:]`` for the first qualifying position, else ``[]``.
    """
    for position, value in enumerate(arr):
        if value > target:
            return arr[position:]
    return []

In [10]:
def get_utm_crs(dr):
    """Return the projected (UTM-style) CRS the dataset ``dr`` should be warped to.

    * Datasets already in a WGS 84 / UTM zone (EPSG 32601-32660 north,
      32701-32760 south) or in EPSG:25832 keep their own CRS object.
    * Geographic WGS 84 datasets (EPSG:4326) get the UTM zone that contains
      the centre of the dataset bounds, returned as an ``"EPSG:<code>"`` string.

    Args:
        dr: An open raster dataset exposing ``crs`` (with ``to_epsg()``) and
            ``bounds`` (left, bottom, right, top).

    Returns:
        ``dr.crs`` unchanged, or an ``"EPSG:326xx"`` / ``"EPSG:327xx"`` string.

    Raises:
        ValueError: If the dataset CRS is none of the supported ones.
    """
    utm_northern_range = range(32601, 32661)
    utm_southern_range = range(32701, 32761)
    epsg_code = dr.crs.to_epsg()
    if epsg_code in utm_northern_range or epsg_code in utm_southern_range:
        return dr.crs
    if epsg_code == 25832:
        return dr.crs
    if epsg_code == 4326:
        # In EPSG:4326 x is the longitude and y the latitude. Use the centre of
        # the raster so the zone is chosen for the bulk of the data.
        left, bottom, right, top = dr.bounds
        longitude = (left + right) / 2
        latitude = (bottom + top) / 2
        # utm.from_latlon takes (latitude, longitude) and returns
        # (easting, northing, zone_number, zone_letter). The zone letter is the
        # latitude *band* (e.g. "S" lies north of the equator), not a
        # hemisphere flag, so the hemisphere is decided from the latitude sign.
        zone_number = utm.from_latlon(latitude, longitude)[2]
        hemisphere_base = 32700 if latitude < 0 else 32600
        return f"EPSG:{hemisphere_base + zone_number}"
    raise ValueError(f"Unknown CRS: {epsg_code}")

In [11]:
def reproject_dataset_to_utm(dataset, resampling_method):
    """Warp every band of ``dataset`` into its UTM CRS and return it in memory.

    The target CRS comes from ``get_utm_crs``; the output grid (transform,
    width, height) is the default one computed by rasterio for that CRS.

    Args:
        dataset: Open rasterio dataset to reproject.
        resampling_method: ``rasterio.enums.Resampling`` member used for warping.

    Returns:
        A ``rasterio.io.MemoryFile`` holding the DEFLATE-compressed result.
        The caller owns it and must close it.

    Raises:
        ValueError: Propagated from ``get_utm_crs`` for unsupported CRSs.
    """
    utm_crs = get_utm_crs(dataset)
    default_transform, width, height = rasterio.warp.calculate_default_transform(
        dataset.crs, utm_crs, dataset.width, dataset.height, *dataset.bounds
    )
    kwargs = dataset.meta.copy()
    kwargs.update(
        {
            "crs": utm_crs,
            "transform": default_transform,
            "width": width,
            "height": height,
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        with memfile.open(**kwargs, compress="DEFLATE") as dst:
            for band_index in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=dataset.transform,
                    src_crs=dataset.crs,
                    dst_transform=default_transform,
                    dst_crs=utm_crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Nobody else holds a reference yet: release the in-memory file
        # instead of leaking it when opening or warping fails.
        memfile.close()
        raise
    return memfile

In [12]:
def rescale_dataset_to_cell_width(dataset, cell_width, resampling_method):
    """Resample ``dataset`` to square pixels of size ``cell_width``, in memory.

    The upper-left corner and the CRS are kept; width and height are scaled by
    the ratio of the current pixel size to ``cell_width`` and rounded up so the
    original extent stays covered. The output grid is north-up (no rotation).

    Args:
        dataset: Open rasterio dataset to resample.
        cell_width: Target pixel size in the units of the dataset CRS.
        resampling_method: ``rasterio.enums.Resampling`` member used for warping.

    Returns:
        A ``rasterio.io.MemoryFile`` holding the tiled, DEFLATE-compressed
        result. The caller owns it and must close it.
    """
    kwargs = dataset.meta.copy()
    kwargs.update(
        {
            "width": int(
                np.ceil(dataset.width * abs(dataset.transform.a) / cell_width)
            ),
            "height": int(
                np.ceil(dataset.height * abs(dataset.transform.e) / cell_width)
            ),
            "transform": rasterio.Affine(
                cell_width,
                0.0,
                dataset.transform.c,
                0.0,
                -cell_width,
                dataset.transform.f,
            ),
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        # Internal blocks use the module-level `tile_overlap` as their edge
        # length, so tile windows read later line up with whole blocks.
        with memfile.open(
            **kwargs,
            compress="DEFLATE",
            tiled=True,
            blockxsize=tile_overlap,
            blockysize=tile_overlap
        ) as dst:
            for band_index in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=dataset.transform,
                    src_crs=dataset.crs,
                    dst_transform=kwargs["transform"],
                    dst_crs=dataset.crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Nobody else holds a reference yet: release the in-memory file
        # instead of leaking it when opening or warping fails.
        memfile.close()
        raise
    return memfile

In [13]:
# Per-orthophoto metadata table. `process_file` looks up the row whose
# "filename" column equals the image file name and reads its "label_quality".
metadata_df = pd.read_csv(metadata_path)

In [14]:
def _filled_fraction(tile, nodata):
    """Return the share of values in ``tile`` that count as valid data."""
    if nodata is None:
        # Without a declared nodata value, zeros are treated as empty.
        return np.count_nonzero(tile) / tile.size
    return np.count_nonzero(tile != nodata) / tile.size


def _write_tile(tile_path, tile_data, tile_metadata):
    """Write ``tile_data`` to ``tile_path`` as a DEFLATE-compressed raster."""
    with rasterio.open(tile_path, "w", **tile_metadata, compress="DEFLATE") as dst:
        dst.write(tile_data)


def _tile_resolution(
    image_reprojected,
    mask_reprojected,
    gdf_label,
    cell_width,
    resolution_out_dir,
    image_filename,
    label_quality,
):
    """Rescale image and mask to ``cell_width`` and export all AOI tiles.

    Returns the register rows (one dict per written image/mask tile pair).
    """
    rows = []
    with contextlib.ExitStack() as stack:
        # Everything entered on the stack is closed in reverse order on exit,
        # also when an exception interrupts the tiling.
        image_rescaled_memfile = stack.enter_context(
            rescale_dataset_to_cell_width(
                image_reprojected, cell_width, rasterio.enums.Resampling.bilinear
            )
        )
        mask_rescaled_memfile = stack.enter_context(
            rescale_dataset_to_cell_width(
                mask_reprojected, cell_width, rasterio.enums.Resampling.nearest
            )
        )
        image_rescaled = stack.enter_context(
            rasterio.open(image_rescaled_memfile, "r+")
        )
        mask_rescaled = stack.enter_context(rasterio.open(mask_rescaled_memfile, "r+"))

        # The dataset profiles do not change while tiling; fetch them once.
        image_meta = image_rescaled.meta
        mask_meta = mask_rescaled.meta

        for _, aoi_row in gdf_label.iterrows():
            aoi_window = geometry_window(image_rescaled, [aoi_row.geometry])
            xmin, ymin = aoi_window.col_off, aoi_window.row_off
            xmax, ymax = xmin + aoi_window.width, ymin + aoi_window.height

            for window in get_windows(
                xmin, ymin, xmax, ymax, tile_width, tile_height, tile_overlap
            ):
                image_tile_name = f"{window.col_off}_{window.row_off}.tif"
                mask_tile_name = f"{window.col_off}_{window.row_off}_mask.tif"
                image_tile_path = os.path.join(resolution_out_dir, image_tile_name)
                mask_tile_path = os.path.join(resolution_out_dir, mask_tile_name)

                # Windows of overlapping AOIs can coincide; keep the first tile.
                if os.path.exists(image_tile_path) or os.path.exists(mask_tile_path):
                    continue

                out_image = image_rescaled.read(window=window)
                # Only (almost) completely filled tiles are exported.
                if not _filled_fraction(out_image, image_meta["nodata"]) > 0.99:
                    continue
                out_mask = mask_rescaled.read(window=window)

                tile_geometry = {
                    "transform": windows.transform(window, image_rescaled.transform),
                    "width": window.width,
                    "height": window.height,
                }
                _write_tile(image_tile_path, out_image, {**image_meta, **tile_geometry})
                _write_tile(mask_tile_path, out_mask, {**mask_meta, **tile_geometry})
                rows.append(
                    {
                        "base_image_name": image_filename,
                        "image_path": image_tile_path,
                        "mask_path": mask_tile_path,
                        "resolution": cell_width,
                        "x": window.col_off,
                        "y": window.row_off,
                        "label_quality": label_quality,
                    }
                )
    return rows


def process_file(mask_filename):
    """Cut one orthophoto and its mask into tiles at several resolutions.

    The image belonging to ``mask_filename`` ("<image>_mask.tif") and the mask
    are reprojected to UTM, resampled to the native cell width plus every
    coarser entry of ``cell_widths``, and cut into ``tile_width`` x
    ``tile_height`` windows inside each polygon of the "aoi" label layer.
    Tiles are written to ``<tiles_out_dir>/<image>/<cell_width>/``.

    Images whose output directory already exists are skipped.
    NOTE(review): the directory is created before tiling starts, so a run
    that crashed half-way is also treated as done -- confirm this is intended.

    Args:
        mask_filename: File name (not path) of a mask inside ``masks_dir``.

    Returns:
        A list with one dict per written tile pair (keys: base_image_name,
        image_path, mask_path, resolution, x, y, label_quality). The list is
        empty when the image is skipped or no tile qualified.

    Raises:
        IndexError: If ``metadata_df`` has no row for the image file name.
    """
    image_filename = mask_filename.replace("_mask.tif", ".tif")
    image_out_dir = os.path.join(tiles_out_dir, image_filename.replace(".tif", ""))

    register_rows = []
    if os.path.exists(image_out_dir):
        # Return an (empty) list like every other path so callers can
        # concatenate the per-file results without special-casing None.
        return register_rows

    mask_filepath = os.path.join(masks_dir, mask_filename)
    image_filepath = os.path.join(images_dir, image_filename)

    file_meta = metadata_df.loc[metadata_df["filename"] == image_filename].to_dict(
        "records"
    )[0]

    label_filename = image_filename.replace(".tif", "_polygons.gpkg")
    label_filepath = os.path.join(labels_dir, label_filename)

    with contextlib.ExitStack() as stack:
        # The source files are only needed for the reprojection; close each one
        # as soon as its in-memory UTM copy exists.
        with rasterio.open(image_filepath) as idr:
            image_memfile = stack.enter_context(
                reproject_dataset_to_utm(idr, rasterio.enums.Resampling.bilinear)
            )
        with rasterio.open(mask_filepath) as mdr:
            mask_memfile = stack.enter_context(
                reproject_dataset_to_utm(mdr, rasterio.enums.Resampling.nearest)
            )
        image_repojected = stack.enter_context(rasterio.open(image_memfile, "r+"))
        mask_repojected = stack.enter_context(rasterio.open(mask_memfile, "r+"))

        gdf_label = gpd.read_file(label_filepath, layer="aoi")
        gdf_label = gdf_label.to_crs(image_repojected.crs)

        os.makedirs(image_out_dir, exist_ok=True)

        # Native resolution first, followed by every configured coarser one.
        native_cell_width = abs(image_repojected.transform[0])
        coarser_cell_widths = shorten_list(cell_widths, native_cell_width)
        for cell_width in sorted(
            np.insert(coarser_cell_widths, 0, native_cell_width)
        ):
            cell_width = round(cell_width, 3)

            resolution_out_dir = os.path.join(image_out_dir, str(cell_width))
            os.makedirs(resolution_out_dir, exist_ok=True)

            register_rows.extend(
                _tile_resolution(
                    image_repojected,
                    mask_repojected,
                    gdf_label,
                    cell_width,
                    resolution_out_dir,
                    image_filename,
                    file_meta["label_quality"],
                )
            )

    return register_rows

In [15]:
# Only files following the "<image>_mask.tif" naming scheme are valid inputs for
# `process_file`; anything else in the directory would be mistaken for a mask.
# Sorting makes the job order reproducible between runs.
mask_filenames = sorted(
    filename for filename in os.listdir(masks_dir) if filename.endswith("_mask.tif")
)
# NOTE(review): presumably `paral` returns one entry per job holding the
# return value of `process_file` -- confirm against utils.parallel.paral.
results = paral(process_file, [mask_filenames], num_cores=cores)

process_file: 100%|██████████| 259/259 [01:40<00:00,  2.57jobs/s] 
