In [1]:
import os
import sys
from contextlib import ExitStack

import geopandas as gpd
import pandas as pd
import numpy as np
import rasterio
import rasterio.enums
import rasterio.io
import rasterio.warp
from rasterio import windows
from rasterio.features import geometry_window
import utm
from pyproj import Transformer
from tqdm import tqdm

# Make the parent of the current working directory (symlinks resolved)
# importable so the project-local `utils` package can be found.
_cwd_realpath = os.path.realpath(os.getcwd())
sys.path.append(os.path.dirname(_cwd_realpath))

from utils.parallel import paral


# Disable GDAL's PAM so it does not write .aux.xml side-car files.
os.environ["GDAL_PAM_ENABLED"] = "NO"

In [2]:
# Output tile size in pixels.
tile_width = 1024
tile_height = 1024
# Overlap between neighbouring tiles in pixels (passed to get_windows). It is
# also used as the internal GeoTIFF block size of the rescaled rasters, which
# GDAL requires to be a multiple of 16.
tile_overlap = 512

# Target pixel sizes in metres (UTM units). np.arange's stop is exclusive, so
# the coarsest target is ~0.26. Images whose native pixel size exceeds
# max(cell_widths) are skipped by process_file.
cell_widths = np.arange(0.02, 0.28, 0.02)

# Inputs: orthophotos, mask rasters, AOI/label GeoPackages, and the per-image
# metadata table (process_file reads its "filename" and "label_quality" columns).
images_dir = "/net/data_ssd/tree_mortality_orthophotos/orthophotos/"
masks_dir = "/net/scratch/cmosig/segmentation_meta/masks/"
labels_dir = "/net/data_ssd/tree_mortality_orthophotos/labels_and_aois/"
metadata_path = "/net/scratch/cmosig/segmentation_meta/metadata_manual_with_resolution.csv"
# Output root: one sub-directory per image, then one per resolution.
tiles_out_dir = "/net/scratch/cmosig/segmentation_meta/tiles_2025/"

# Number of cores passed to paral().
cores = 128

# Fail fast on settings that would break tiling or GeoTIFF block creation.
assert 0 < tile_overlap < min(tile_width, tile_height), (
    "tile_overlap must be positive and smaller than the tile size"
)
assert tile_overlap % 16 == 0, (
    "tile_overlap is used as the GeoTIFF block size and must be a multiple of 16"
)
assert cell_widths.size > 0 and bool((cell_widths > 0).all()), (
    "cell_widths must be non-empty and strictly positive"
)

In [3]:
def shorten_list(arr, target):
    """Drop the leading elements of `arr` up to the first one above `target`.

    Returns ``arr[i:]``, where ``i`` is the index of the first element strictly
    greater than ``target``. For an ascending ``arr`` this is every element
    above ``target``. The slice keeps ``arr``'s type, so a numpy array stays an
    array. If no element qualifies, an empty Python list is returned.
    """
    for position, value in enumerate(arr):
        if value > target:
            return arr[position:]
    return []

In [4]:
def get_utm_crs(dr):
    """Choose a UTM CRS for a raster dataset.

    If the dataset already uses a WGS84 UTM CRS (EPSG 32601-32660 north,
    32701-32760 south) or an ETRS89 UTM CRS (EPSG 25801-25860), its own CRS
    object is returned unchanged. Otherwise the raster's upper-left corner is
    converted to lon/lat and the WGS84 UTM zone containing that point is
    returned as an ``"EPSG:<code>"`` string.

    Parameters
    ----------
    dr : rasterio dataset
        Only ``.crs`` (with ``to_epsg()``/``to_wkt()``) and ``.transform`` are used.

    Returns
    -------
    rasterio CRS or str
        Both forms are accepted by rasterio's warp functions.
    """
    utm_epsg_ranges = (
        range(32601, 32661),  # WGS84 / UTM north
        range(32701, 32761),  # WGS84 / UTM south
        range(25801, 25861),  # ETRS89 / UTM
    )
    epsg_code = dr.crs.to_epsg()
    if epsg_code is not None and any(epsg_code in r for r in utm_epsg_ranges):
        return dr.crs

    # Upper-left corner of the raster (Affine coefficients c, f) in its own CRS.
    x, y = dr.transform[2], dr.transform[5]
    if epsg_code == 4326:
        lon, lat = x, y
    else:
        # to_epsg() returns None for CRSs without an EPSG match; the old
        # f"epsg:{epsg_code}" then became "epsg:None" and failed. Fall back to WKT.
        source_crs = f"epsg:{epsg_code}" if epsg_code is not None else dr.crs.to_wkt()
        transformer = Transformer.from_crs(source_crs, "epsg:4326", always_xy=True)
        lon, lat = transformer.transform(x, y)

    _, _, zone_number, zone_letter = utm.from_latlon(lat, lon)
    # Latitude bands "N" and above are northern hemisphere.
    utm_code = (32600 if zone_letter >= "N" else 32700) + zone_number
    return f"EPSG:{utm_code}"

In [5]:
def reproject_dataset_to_utm(dataset, resampling_method):
    """Warp every band of `dataset` into a UTM CRS, held in memory.

    The target CRS comes from ``get_utm_crs`` and the target grid from
    ``rasterio.warp.calculate_default_transform``. The output copies the
    dataset's metadata except CRS, transform, width and height, and is written
    DEFLATE-compressed.

    Parameters
    ----------
    dataset : rasterio dataset opened for reading
    resampling_method : rasterio.enums.Resampling
        E.g. bilinear for imagery, nearest for masks.

    Returns
    -------
    rasterio.io.MemoryFile
        Contains the written raster; the caller is responsible for closing it.
    """
    utm_crs = get_utm_crs(dataset)
    utm_transform, utm_width, utm_height = rasterio.warp.calculate_default_transform(
        dataset.crs, utm_crs, dataset.width, dataset.height, *dataset.bounds
    )
    profile = dataset.meta.copy()
    profile.update(
        {
            "crs": utm_crs,
            "transform": utm_transform,
            "width": utm_width,
            "height": utm_height,
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        with memfile.open(**profile, compress="DEFLATE") as dst:
            for band_index in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=dataset.transform,
                    src_crs=dataset.crs,
                    dst_transform=utm_transform,
                    dst_crs=utm_crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Don't leak the in-memory GDAL file if warping fails part-way.
        memfile.close()
        raise
    return memfile

In [6]:
def rescale_dataset_to_cell_width(dataset, cell_width, resampling_method, block_size=None):
    """Resample `dataset` to square pixels of size `cell_width` in its own CRS.

    The output keeps the dataset's CRS and upper-left corner. Width and height
    are rounded up so the new grid covers the whole original extent. The new
    transform has no rotation terms and a negative y pixel size, i.e. a
    north-up source grid is assumed. The result is a tiled, DEFLATE-compressed
    in-memory raster.

    Parameters
    ----------
    dataset : rasterio dataset opened for reading
    cell_width : float
        Target pixel size in the dataset's CRS units.
    resampling_method : rasterio.enums.Resampling
    block_size : int, optional
        Internal GeoTIFF block size. Defaults to the global ``tile_overlap``,
        read at call time, which was the previous hard-wired behaviour. GDAL
        requires a multiple of 16.

    Returns
    -------
    rasterio.io.MemoryFile
        Contains the written raster; the caller is responsible for closing it.
    """
    if block_size is None:
        block_size = tile_overlap

    src_transform = dataset.transform
    dst_transform = rasterio.Affine(
        cell_width, 0.0, src_transform.c, 0.0, -cell_width, src_transform.f
    )
    kwargs = dataset.meta.copy()
    kwargs.update(
        {
            "width": int(np.ceil(dataset.width * abs(src_transform.a) / cell_width)),
            "height": int(np.ceil(dataset.height * abs(src_transform.e) / cell_width)),
            "transform": dst_transform,
        }
    )
    memfile = rasterio.io.MemoryFile()
    try:
        with memfile.open(
            **kwargs,
            compress="DEFLATE",
            tiled=True,
            blockxsize=block_size,
            blockysize=block_size,
        ) as dst:
            for band_index in range(1, dataset.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(dataset, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=src_transform,
                    src_crs=dataset.crs,
                    dst_transform=dst_transform,
                    dst_crs=dataset.crs,
                    resampling=resampling_method,
                )
    except BaseException:
        # Don't leak the in-memory GDAL file if resampling fails part-way.
        memfile.close()
        raise
    return memfile

In [7]:
# Per-image metadata. process_file matches rows on "filename" and copies
# "label_quality" into every register row, so check both columns exist here,
# before any workers are started.
metadata_df = pd.read_csv(metadata_path)
_required_metadata_columns = {"filename", "label_quality"}
assert _required_metadata_columns <= set(metadata_df.columns), (
    f"{metadata_path} is missing columns: "
    f"{sorted(_required_metadata_columns - set(metadata_df.columns))}"
)

In [8]:
def _resolution_steps(native_cell_width):
    """Sorted, de-duplicated cell widths (rounded to 3 decimals) to tile at.

    Always includes the raster's own cell width, plus every entry of
    ``cell_widths`` strictly coarser than it. After rounding, the native width
    can coincide with a configured one (e.g. 0.0399 -> 0.04); the set drops
    that duplicate so the same resolution is not processed twice.
    """
    candidates = np.insert(
        shorten_list(cell_widths, native_cell_width), 0, native_cell_width
    )
    return sorted({round(width, 3) for width in candidates})


def _filled_fraction(mask, nodata):
    """Fraction of `mask` values that are labelled.

    Labelled means non-zero when `nodata` is None, otherwise != `nodata`.
    An empty array gives 0.0.
    """
    if mask.size == 0:
        return 0.0
    if nodata is None:
        labelled = np.count_nonzero(mask)
    else:
        labelled = np.count_nonzero(mask != nodata)
    return labelled / mask.size


def _write_tile(path, data, template, window, window_transform):
    """Write `data` as a DEFLATE-compressed raster using `template`'s metadata cropped to `window`."""
    profile = template.meta.copy()
    profile.update(
        {
            "transform": window_transform,
            "width": window.width,
            "height": window.height,
        }
    )
    with rasterio.open(path, "w", **profile, compress="DEFLATE") as dst:
        dst.write(data)


def _iter_written_tiles(image_rescaled, mask_rescaled, aoi_geometries, out_dir, min_filled_fraction):
    """Write image/mask tile pairs covering each AOI at one resolution.

    Yields ``(window, image_tile_path, mask_tile_path)`` for every pair written.
    """
    for aoi_geometry in aoi_geometries:
        aoi_window = geometry_window(image_rescaled, [aoi_geometry])
        xmin, ymin = aoi_window.col_off, aoi_window.row_off
        xmax, ymax = xmin + aoi_window.width, ymin + aoi_window.height

        # NOTE(review): get_windows is neither defined nor imported anywhere in
        # this notebook, so a fresh-kernel run raises NameError here. Define or
        # import it; it must yield rasterio Window objects.
        for window in get_windows(
            xmin, ymin, xmax, ymax, tile_width, tile_height, tile_overlap
        ):
            tile_stem = f"{window.col_off}_{window.row_off}"
            image_tile_path = os.path.join(out_dir, f"{tile_stem}.tif")
            mask_tile_path = os.path.join(out_dir, f"{tile_stem}_mask.tif")
            # A window already written in this run (e.g. by an overlapping AOI)
            # is neither rewritten nor registered again.
            if os.path.exists(image_tile_path) or os.path.exists(mask_tile_path):
                continue

            out_mask = mask_rescaled.read(window=window)
            if _filled_fraction(out_mask, mask_rescaled.nodata) < min_filled_fraction:
                continue
            # Only read the image once the mask tile has been accepted.
            out_image = image_rescaled.read(window=window)

            # Both tiles are georeferenced with the image grid's window transform.
            window_transform = windows.transform(window, image_rescaled.transform)
            os.makedirs(out_dir, exist_ok=True)
            _write_tile(image_tile_path, out_image, image_rescaled, window, window_transform)
            _write_tile(mask_tile_path, out_mask, mask_rescaled, window, window_transform)
            yield window, image_tile_path, mask_tile_path


def process_file(mask_filename, min_filled_fraction=0.3):
    """Tile one orthophoto/mask pair at several resolutions.

    The image (bilinear) and mask (nearest) are reprojected to UTM. They are
    then rescaled to the native cell width and to every coarser width in
    ``cell_widths``. For each AOI polygon (layer "aoi" of the label
    GeoPackage), windows from ``get_windows`` are cut out. A tile pair is
    written under ``tiles_out_dir/<image>/<cell_width>/`` only if at least
    ``min_filled_fraction`` of its mask values are labelled.

    Parameters
    ----------
    mask_filename : str
        File name inside ``masks_dir``; the image name is derived by replacing
        "_mask.tif" with ".tif".
    min_filled_fraction : float, optional
        Minimum labelled fraction for a mask tile to be kept. Defaults to 0.3,
        the previously hard-coded value.

    Returns
    -------
    list of dict
        One register row per written tile pair. The list is empty, rather than
        None as before, when the image's output directory already exists or
        the image is too coarse to tile.

    Raises
    ------
    IndexError
        If ``metadata_df`` has no row for the image.
    """
    image_filename = mask_filename.replace("_mask.tif", ".tif")
    image_out_dir = os.path.join(tiles_out_dir, image_filename.replace(".tif", ""))
    # Resume support: an existing output directory means an earlier run handled
    # this image. NOTE(review): a run that crashed mid-image also leaves the
    # directory behind, so that image is never completed — consider a marker file.
    if os.path.exists(image_out_dir):
        return []

    image_filepath = os.path.join(images_dir, image_filename)
    mask_filepath = os.path.join(masks_dir, mask_filename)
    label_filepath = os.path.join(
        labels_dir, image_filename.replace(".tif", "_polygons.gpkg")
    )

    meta_records = metadata_df[metadata_df["filename"] == image_filename].to_dict("records")
    if not meta_records:
        raise IndexError(f"No metadata row for {image_filename}")
    label_quality = meta_records[0]["label_quality"]

    register_rows = []
    # ExitStack closes every dataset and memfile opened below, in reverse
    # order, even if an exception is raised part-way through.
    with ExitStack() as stack:
        idr = stack.enter_context(rasterio.open(image_filepath))
        mdr = stack.enter_context(rasterio.open(mask_filepath))

        # Pixel width of the image when warped to the metric CRS EPSG:8857.
        metric_transform, _, _ = rasterio.warp.calculate_default_transform(
            src_crs=idr.crs,
            dst_crs="EPSG:8857",
            width=idr.width,
            height=idr.height,
            left=idr.bounds.left,
            bottom=idr.bounds.bottom,
            right=idr.bounds.right,
            top=idr.bounds.top,
        )
        if metric_transform.a > max(cell_widths):
            print(f"Skipping {image_filename} due to resolution")
            return []

        image_memfile = stack.enter_context(
            reproject_dataset_to_utm(idr, rasterio.enums.Resampling.bilinear)
        )
        mask_memfile = stack.enter_context(
            reproject_dataset_to_utm(mdr, rasterio.enums.Resampling.nearest)
        )
        # The source files are no longer needed; release them before tiling.
        idr.close()
        mdr.close()
        image_reprojected = stack.enter_context(image_memfile.open())
        mask_reprojected = stack.enter_context(mask_memfile.open())

        aoi_geometries = (
            gpd.read_file(label_filepath, layer="aoi")
            .to_crs(image_reprojected.crs)
            .geometry
        )

        for cell_width in _resolution_steps(abs(image_reprojected.transform.a)):
            resolution_out_dir = os.path.join(image_out_dir, str(cell_width))
            with ExitStack() as resolution_stack:
                image_rescaled_memfile = resolution_stack.enter_context(
                    rescale_dataset_to_cell_width(
                        image_reprojected, cell_width, rasterio.enums.Resampling.bilinear
                    )
                )
                mask_rescaled_memfile = resolution_stack.enter_context(
                    rescale_dataset_to_cell_width(
                        mask_reprojected, cell_width, rasterio.enums.Resampling.nearest
                    )
                )
                image_rescaled = resolution_stack.enter_context(image_rescaled_memfile.open())
                mask_rescaled = resolution_stack.enter_context(mask_rescaled_memfile.open())

                for window, image_tile_path, mask_tile_path in _iter_written_tiles(
                    image_rescaled,
                    mask_rescaled,
                    aoi_geometries,
                    resolution_out_dir,
                    min_filled_fraction,
                ):
                    register_rows.append(
                        {
                            "base_image_name": image_filename,
                            "image_path": image_tile_path,
                            "mask_path": mask_tile_path,
                            "resolution": cell_width,
                            "x": window.col_off,
                            "y": window.row_off,
                            "label_quality": label_quality,
                        }
                    )

    return register_rows

In [9]:
# Tile every mask in parallel; each process_file call returns the register
# rows for the tiles it wrote.
# - Only "*_mask.tif" files are submitted, because process_file derives the
#   image name by replacing that suffix. Any other file in masks_dir would
#   most likely fail the metadata lookup.
# - Sorting makes the job order reproducible (os.listdir order is arbitrary).
mask_filenames = sorted(
    name for name in os.listdir(masks_dir) if name.endswith("_mask.tif")
)
results = paral(process_file, [mask_filenames], num_cores=cores)

process_file:  19%|█▉        | 90/463 [00:09<00:24, 15.00jobs/s] 

Skipping california_cropped_18_2020_8_5.tif due to resolution
Skipping california_cropped_7_2020_7_26.tif due to resolution
Skipping resampled_swissimage-dop10_2018_2590-1136_0.tif due to resolution
Skipping california_cropped_15_2020_8_2.tif due to resolution
Skipping california_cropped_1_2020_8_5.tif due to resolution
Skipping california_cropped_13_2020_8_2.tif due to resolution


process_file:  32%|███▏      | 146/463 [00:09<00:09, 32.75jobs/s]

Skipping california_cropped_14_2020_8_2.tif due to resolution
Skipping resampled_swissimage-dop10_2018_2692-1179_0.tif due to resolution
Skipping resampled_swissimage-dop10_2022_2696-1215_0.tif due to resolution
Skipping california_cropped_19_2020_8_5.tif due to resolution
Skipping resampled_swissimage-dop10_2020_2576-1114_0.tif due to resolution
Skipping california_cropped_0_2020_8_5.tif due to resolution
Skipping resampled_swissimage-dop10_2022_2766-1213_0.tif due to resolution


process_file:  50%|████▉     | 231/463 [00:10<00:02, 83.01jobs/s]

Skipping california_cropped_6_2020_8_2.tif due to resolution
Skipping california_cropped_12_2020_8_2.tif due to resolution
Skipping california_cropped_11_2020_8_5.tif due to resolution
Skipping california_cropped_8_2020_8_2.tif due to resolution


process_file:  62%|██████▏   | 286/463 [00:10<00:01, 125.93jobs/s]

Skipping california_cropped_3_2020_8_2.tif due to resolution
Skipping california_cropped_5_2020_8_2.tif due to resolution
Skipping california_cropped_4_2020_8_3.tif due to resolution
Skipping california_cropped_17_2020_8_5.tif due to resolution


process_file:  82%|████████▏ | 381/463 [00:10<00:00, 215.14jobs/s]

Skipping california_cropped_16_2020_8_2.tif due to resolution
Skipping california_cropped_9_2020_8_2.tif due to resolution
Skipping california_cropped_10_2020_8_2.tif due to resolution
Skipping resampled_swissimage-dop10_2020_2587-1131_0.tif due to resolution
Skipping resampled_swissimage-dop10_2018_2594-1151_0.tif due to resolution
Skipping resampled_swissimage-dop10_2021_2686-1139_0.tif due to resolution
Skipping resampled_swissimage-dop10_2021_2633-1157_0.tif due to resolution
Skipping california_cropped_2_2020_8_5.tif due to resolution


process_file: 100%|█████████▉| 462/463 [00:21<00:00, 174.15jobs/s]