In [4]:
import os
import re

import numpy as np
import rasterio
from patchify import patchify

# Patch edge length in pixels (6 x 256 = 1536). The stride equals the patch
# size, so patches tile the raster without overlapping.
patch_size = 6 * 256
step = patch_size

# Define a function to read raster images using rasterio
def read_raster(file_path):
    with rasterio.open(file_path) as src:
        return src.read(), src.meta, src.transform

# Function to save patches using rasterio
def save_patch(data, meta, transform, patch_path):
    meta.update({
        "height": data.shape[1],
        "width": data.shape[2],
        "transform": transform
    })
    with rasterio.open(patch_path, 'w', **meta) as dst:
        dst.write(data)

mask_path = r"D:/Hesham/CUAHSI/Geospatial/Datasets/Forecasts/Beryl/test_patch"

# Patches are written back into mask_path as "mask_<source name>_<row>_<col>.tif".
# Re-running this cell must not treat those outputs as new inputs, so names of
# that shape are skipped. The source name must carry an extension, e.g.
# "mask_a.tif_0_0.tif".
_PATCH_OUTPUT_RE = re.compile(r"^mask_.+\.[^._]+_\d+_\d+\.tif$")

# os.listdir returns a list, so files created below do not join this loop.
for msk_name in os.listdir(mask_path):
    msk_path = os.path.join(mask_path, msk_name)
    # Skip subdirectories (rasterio cannot open them) and previously generated patches.
    if not os.path.isfile(msk_path) or _PATCH_OUTPUT_RE.match(msk_name):
        continue
    mask_data, mask_meta, mask_transform = read_raster(msk_path)  # Read mask and metadata

    # Only the first band is patchified. Windows that do not fit completely
    # inside the raster are presumably dropped by patchify — verify if edge
    # coverage matters.
    mask_patches = patchify(mask_data[0], (patch_size, patch_size), step=step)

    # Every patch written below has exactly one band. Use a private copy so
    # mask_meta is not modified and the band count matches the data.
    patch_meta = {**mask_meta, "count": 1}

    for i in range(mask_patches.shape[0]):
        for j in range(mask_patches.shape[1]):
            patch = mask_patches[i, j]
            patch = np.expand_dims(patch, axis=0)  # Add band dimension back

            # Shift the source transform by the patch's pixel offset: column j*step, row i*step.
            patch_transform = mask_transform * rasterio.Affine.translation(j * step, i * step)

            patch_name = f"mask_{msk_name}_{i}_{j}.tif"
            patch_path = os.path.join(mask_path, patch_name)
            save_patch(patch, patch_meta, patch_transform, patch_path)

