In [1]:
import argparse
import concurrent.futures
import os

import fiona
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
from rasterio.transform import from_bounds
from shapely.geometry import mapping
from tqdm import tqdm

In [2]:
# Directory that is listed for label files; process_file joins each filename onto it.
labels_dir = "/net/data_ssd/tree_mortality_orthophotos/labels_and_aois/"
# Directory the "<name>_mask.tif" rasters are written to.
out_dir = "/net/scratch/jmoehring/masks/"
# CSV read into df_meta; process_file reads its height/width/top/bottom/left/right
# columns, and a "filename" column is used to derive the lookup key.
metadata_path = "/net/scratch/jmoehring/images_meta.csv"
# Passed as max_workers to the ProcessPoolExecutor that runs process_file.
jobs = 50

In [3]:
# Layers of a label GeoPackage whose polygons are burned into the mask.
_LABEL_LAYERS = ("standing_deadwood", "brown_trees", "parts")


def _collect_layer_shapes(filepath, layer):
    """Return the polygons of one GeoPackage layer as GeoJSON-like dicts in EPSG:4326."""
    gdf_label = gpd.read_file(filepath, layer=layer).to_crs("EPSG:4326")
    return [
        mapping(geometry)
        for geometry in gdf_label["geometry"]
        # Missing or empty geometries cannot be rasterized; skip them instead of
        # letting one bad row abort the whole file.
        if geometry is not None and not geometry.is_empty
    ]


def process_file(filename, df_meta):
    """Rasterize the label polygons of one GeoPackage into a binary GeoTIFF mask.

    The file is skipped (nothing is written) when it is not a ``.gpkg`` file,
    when its name with ``_polygons.gpkg`` removed has no entry in
    ``df_meta["filename_map"]``, or when the output mask already exists.

    Otherwise a ``uint8`` raster with the height/width and the
    top/bottom/left/right bounds taken from the metadata row is created, every
    polygon found in the ``standing_deadwood``, ``brown_trees`` and ``parts``
    layers (those that exist in the file) is burned in as 1, and the result is
    saved as ``<out_dir>/<name>_mask.tif`` with CRS EPSG:4326.

    Args:
        filename: Name of a file inside ``labels_dir`` (not a full path).
        df_meta: DataFrame with a ``filename_map`` column plus ``height``,
            ``width``, ``top``, ``bottom``, ``left`` and ``right`` columns.

    Returns:
        None.
    """
    if not filename.endswith(".gpkg"):
        return

    filename_map = filename.replace("_polygons.gpkg", "")
    if filename_map not in df_meta["filename_map"].values:
        return

    out_filepath = os.path.join(out_dir, filename_map + "_mask.tif")
    if os.path.exists(out_filepath):
        return

    filepath = os.path.join(labels_dir, filename)

    # Get metadata for current gpkg file (first matching row wins)
    file_meta = df_meta.loc[df_meta["filename_map"] == filename_map].to_dict(
        "records"
    )[0]
    # CSV columns may be parsed as floats; array shapes must be integers.
    height = int(file_meta["height"])
    width = int(file_meta["width"])
    out_image = np.zeros((height, width), dtype=np.uint8)
    transform = from_bounds(
        north=file_meta["top"],
        south=file_meta["bottom"],
        west=file_meta["left"],
        east=file_meta["right"],
        width=width,
        height=height,
    )

    # Gather the polygons of every label layer that is present in the file
    layers = fiona.listlayers(filepath)
    shapes = []
    for layer in _LABEL_LAYERS:
        if layer in layers:
            shapes.extend(_collect_layer_shapes(filepath, layer))

    # One rasterization pass over all shapes yields the union of the per-polygon
    # masks without allocating a full-size mask per polygon. An empty shape list
    # is not accepted by the rasterizer, so the all-zero image is kept as is.
    if shapes:
        mask = geometry_mask(
            shapes,
            transform=transform,
            invert=True,
            out_shape=out_image.shape,
        )
        out_image[mask] = 1

    # Save image. Write to a temporary name first and rename afterwards so an
    # interrupted run never leaves a truncated mask that the exists() check
    # above would treat as finished.
    tmp_filepath = os.path.join(out_dir, filename_map + "_mask.tmp.tif")
    try:
        with rasterio.open(
            tmp_filepath,
            "w",
            driver="GTiff",
            compress="DEFLATE",
            height=out_image.shape[0],
            width=out_image.shape[1],
            count=1,
            dtype="uint8",
            crs="EPSG:4326",
            transform=transform,
        ) as dst:
            dst.write(out_image, 1)
        os.replace(tmp_filepath, out_filepath)
    finally:
        # After a successful rename the temporary file no longer exists.
        if os.path.exists(tmp_filepath):
            os.remove(tmp_filepath)

In [4]:
# Add a "filename_map" column (image filename without ".tif") that is used to
# find the metadata row for each label file.
df_meta = pd.read_csv(metadata_path)
# regex=False: treat ".tif" literally. Before pandas 2.0 the default was
# regex=True, where "." matches any character (e.g. "_tif" would be stripped too).
df_meta["filename_map"] = df_meta["filename"].str.replace(".tif", "", regex=False)

In [5]:
# List the label directory once so the progress total and the submitted tasks
# are guaranteed to refer to the same set of files.
label_filenames = os.listdir(labels_dir)
failures = []

with tqdm(total=len(label_filenames)) as pbar:
    with concurrent.futures.ProcessPoolExecutor(max_workers=jobs) as executor:
        future_to_filename = {
            executor.submit(process_file, filename, df_meta): filename
            for filename in label_filenames
        }
        for future in concurrent.futures.as_completed(future_to_filename):
            # result() re-raises whatever the worker raised; without this call
            # a failing file is indistinguishable from a processed one.
            try:
                future.result()
            except Exception as exc:
                failures.append((future_to_filename[future], exc))
            pbar.update(1)

# Report failures after the run instead of aborting on the first one, so a
# single bad label file does not stop the remaining files from being processed.
if failures:
    print(f"{len(failures)} of {len(label_filenames)} files failed:")
    for failed_filename, exc in failures:
        print(f"  {failed_filename}: {exc!r}")

100%|██████████| 910/910 [00:00<00:00, 1545.26it/s]
