In [1]:
import os
import re
import geopandas as gpd
import rasterio as rio
from rasterio.mask import mask
from rasterio.windows import from_bounds, transform

In [2]:
# Root directory for the exported patches; one sub-directory is created per class label.
output_path = './data/tree_classification'
# False -> clip each polygon with rasterio's `mask` (crop=True).
# True  -> read the polygon's whole bounding-box window, keeping the background pixels.
use_background_pixels = False

# makedirs(exist_ok=True) also creates missing parent folders and tolerates re-runs.
os.makedirs(output_path, exist_ok=True)

# Number of patches written so far for each class label in this run; used as the
# per-class sequence number in the output file names.
class_map = {}

# geotiff_paths[i] is paired with shapefile_paths[i] by the zip() below, so keep
# both lists in the same zone order.
geotiff_paths = ['./data/2021-06-17/zone1/2021-06-17-sbl-z1-rgb-cog.tif', './data/2021-06-17/zone2/2021-06-17-sbl-z2-rgb-cog.tif', './data/2021-06-17/zone3/2021-06-17-sbl-z3-rgb-cog.tif']
shapefile_paths = ['./data/Z1_polygons.gpkg', './data/Z2_polygons.gpkg', './data/Z3_polygons.gpkg']


for file, shapes in zip(geotiff_paths, shapefile_paths):
    # Zone tag (e.g. "Z1") taken from the polygon file name; it prefixes every patch name.
    zone_match = re.search(r'Z\d', os.path.basename(shapes))
    if zone_match is None:
        raise ValueError(f'cannot find a zone tag like "Z1" in {shapes!r}')
    zone = zone_match.group(0)

    with rio.open(file) as tif:
        gdf = gpd.read_file(shapes)
        # Polygon coordinates are used directly against the raster's transform, so bring
        # them into the raster's CRS first (skipped when they already match).
        if gdf.crs is not None and tif.crs is not None and gdf.crs != tif.crs:
            gdf = gdf.to_crs(tif.crs)

        for j, shape in gdf.iterrows():
            geometry = shape['geometry']
            class_label = shape['Label']

            # A row without a label (None or NaN; NaN != NaN) has no class folder to go to.
            if class_label is None or class_label != class_label:
                print(f'skipping shape {j} in zone {zone}: missing Label')
                continue
            class_label = str(class_label)

            try:
                if use_background_pixels:
                    minx, miny, maxx, maxy = geometry.bounds
                    # NOTE(review): from_bounds may return a fractional window; consider
                    # rounding it so the pixels read and out_transform agree exactly.
                    window = from_bounds(minx, miny, maxx, maxy, tif.transform)
                    out_image = tif.read(window=window)
                    out_transform = transform(window, tif.transform)
                else:
                    out_image, out_transform = mask(tif, [geometry], crop=True)
            except Exception as err:
                # Presumably a polygon outside this raster or a missing geometry. Catching
                # Exception (not a bare except) still lets KeyboardInterrupt through.
                print(f'error masking shape {j} in zone {zone}: {err}')
                continue

            # Copy the source profile and override size/transform for the cropped patch;
            # out_image is indexed (band, row, col), hence shape[1]/shape[2].
            out_meta = tif.meta.copy()
            out_meta.update({
                "driver": "GTiff",
                "height": out_image.shape[1],
                "width": out_image.shape[2],
                "transform": out_transform
            })

            class_path = os.path.join(output_path, class_label)
            os.makedirs(class_path, exist_ok=True)
            # Count from the in-memory map rather than from whether the folder exists: a
            # folder left over from an earlier run would otherwise reach `+= 1` on a missing
            # key and raise KeyError.
            class_map[class_label] = class_map.get(class_label, 0) + 1

            file_path = os.path.join(class_path, f"{zone}_{class_map[class_label]}.tif")
            with rio.open(file_path, "w", **out_meta) as dest:
                dest.write(out_image)