In [25]:
from osgeo import gdal
import os
import numpy as np

# Input folders: source rasters and their label masks.
# Built with os.path.join so the separator matches the host OS (the original
# raw strings hard-coded Windows backslashes).
image_folder = os.path.join('data_original', 'image')
mask_folder = os.path.join('data_original', 'mask')

# Output folders that receive the tiled patches.
out_image_folder = os.path.join('data_patches', 'image')
out_mask_folder = os.path.join('data_patches', 'mask')

# create output folders (no error if they already exist)
os.makedirs(out_image_folder, exist_ok=True)
os.makedirs(out_mask_folder, exist_ok=True)

# Keep only PNG files (extension matched case-insensitively so '.PNG' is not
# silently dropped) and sort both listings so that images and masks can be
# paired by position later on.
image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith('.png'))
mask_files = sorted(f for f in os.listdir(mask_folder) if f.lower().endswith('.png'))

# Displayed as the cell output: the two counts are expected to be equal.
len(image_files), len(mask_files)

(24, 24)

In [27]:
from PIL import Image

# Tile every image/mask pair into non-overlapping patch_size x patch_size PNGs.
# Image and Mask sizes are: (7680, 7680) (7680, 7680)
patch_size = 512

# Images and masks are paired purely by position in two independently sorted
# listings, so a missing file on either side would mis-pair files or fail
# midway through the run. Fail early instead.
if len(image_files) != len(mask_files):
    raise ValueError(
        f"Image/mask count mismatch: {len(image_files)} images vs {len(mask_files)} masks"
    )

for image_file, mask_file in zip(image_files, mask_files):
    image_path = os.path.join(image_folder, image_file)
    mask_path = os.path.join(mask_folder, mask_file)

    image = gdal.Open(image_path)
    mask = gdal.Open(mask_path)

    # gdal.Open may return None instead of raising when a file cannot be read.
    if image is None:
        raise RuntimeError(f"GDAL could not open image: {image_path}")
    if mask is None:
        raise RuntimeError(f"GDAL could not open mask: {mask_path}")

    image_array = image.ReadAsArray()  # shape: (7680, 7680)
    mask_array = mask.ReadAsArray()  # shape: (7680, 7680)

    # Replace zero pixels with the (truncated) mean of the non-zero pixels.
    # Guarded: with no non-zero pixel the mean is NaN and int(NaN) raises;
    # with no zero pixel there is nothing to fill.
    zero_values_mask = image_array == 0
    if zero_values_mask.any() and not zero_values_mask.all():
        image_array[zero_values_mask] = int(np.mean(image_array[~zero_values_mask]))

    h, w = image_array.shape
    full_patch = (patch_size, patch_size)

    for y in range(0, h, patch_size):
        for x in range(0, w, patch_size):
            patch_image = image_array[y:y+patch_size, x:x+patch_size]
            patch_mask = mask_array[y:y+patch_size, x:x+patch_size]

            # Skip partial tiles at the right/bottom edge, and tiles where the
            # mask does not fully cover the image window.
            if patch_image.shape[:2] != full_patch or patch_mask.shape[:2] != full_patch:
                continue

            # Patch names: first 16 characters of the source file name plus
            # the top-left pixel offset (row, then column) of the tile.
            out_image_name = f'{image_file[:16]}_{y}_{x}.png'
            out_mask_name = f'{mask_file[:16]}_{y}_{x}.png'

            out_image_path = os.path.join(out_image_folder, out_image_name)
            out_mask_path = os.path.join(out_mask_folder, out_mask_name)

            Image.fromarray(patch_image).save(out_image_path)
            Image.fromarray(patch_mask).save(out_mask_path)

            print(f"Saved: {out_image_path} {out_mask_path}")



Saved: data_patches\image\thm_dir_N-30_000_0_0.png data_patches\mask\thm_dir_N-30_000_0_0.png
Saved: data_patches\image\thm_dir_N-30_000_0_512.png data_patches\mask\thm_dir_N-30_000_0_512.png
Saved: data_patches\image\thm_dir_N-30_000_0_1024.png data_patches\mask\thm_dir_N-30_000_0_1024.png
Saved: data_patches\image\thm_dir_N-30_000_0_1536.png data_patches\mask\thm_dir_N-30_000_0_1536.png
Saved: data_patches\image\thm_dir_N-30_000_0_2048.png data_patches\mask\thm_dir_N-30_000_0_2048.png
Saved: data_patches\image\thm_dir_N-30_000_0_2560.png data_patches\mask\thm_dir_N-30_000_0_2560.png
Saved: data_patches\image\thm_dir_N-30_000_0_3072.png data_patches\mask\thm_dir_N-30_000_0_3072.png
Saved: data_patches\image\thm_dir_N-30_000_0_3584.png data_patches\mask\thm_dir_N-30_000_0_3584.png
Saved: data_patches\image\thm_dir_N-30_000_0_4096.png data_patches\mask\thm_dir_N-30_000_0_4096.png
Saved: data_patches\image\thm_dir_N-30_000_0_4608.png data_patches\mask\thm_dir_N-30_000_0_4608.png
Saved: d