In [None]:
import os
import glob
import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.merge import merge
from osgeo import gdal
from tensorflow.keras.models import load_model

In [None]:
###########################Raster tiling##########################

# Tile a large GeoTIFF into overlapping chips (128x128 with a 112-px stride by default) for classification.
def tile_offsets(length, size, stride, cover_edges=False):
    """Return the start offsets of windows of ``size`` pixels along an axis of ``length`` pixels.

    Offsets are ``0, stride, 2*stride, ...`` as long as the window stays fully inside the axis.
    With ``cover_edges=True`` one extra offset ``length - size`` is appended when the regular
    grid leaves uncovered pixels at the end of the axis.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    offsets = list(range(0, length - size + 1, stride))
    if cover_edges and offsets and offsets[-1] + size < length:
        offsets.append(length - size)
    return offsets


def process_raster(input_file, output_dir, window_size=(128, 128), stride=112, cover_edges=False):
    """Tile a GeoTIFF into overlapping fixed-size chips and write one GeoTIFF per chip.

    Windows always lie fully inside the raster. With the default ``cover_edges=False``, any
    strip along the right/bottom edge that is not reached by the stride grid is skipped (as
    before). ``cover_edges=True`` adds one edge-aligned window per axis so those pixels are
    tiled too; that extra tile overlaps its neighbour by more than the regular amount.
    A raster smaller than ``window_size`` produces no tiles.

    Args:
        input_file: Path of the source raster.
        output_dir: Existing directory that receives ``<stem>_<n>.tif`` tiles.
        window_size: ``(width, height)`` of each tile in pixels.
        stride: Step in pixels between consecutive window origins on both axes.
        cover_edges: Also tile the right/bottom remainder (see above).

    Returns:
        int: Number of tiles written.
    """
    win_width, win_height = window_size
    stem = os.path.splitext(os.path.basename(input_file))[0]

    with rasterio.open(input_file) as src:
        # Everything except the georeferencing is identical for every tile.
        meta = src.meta.copy()
        meta.update({"driver": "GTiff", "height": win_height, "width": win_width})

        row_offsets = tile_offsets(src.height, win_height, stride, cover_edges)
        col_offsets = tile_offsets(src.width, win_width, stride, cover_edges)

        counter = 0  # Tile index used as the filename suffix
        for row_off in row_offsets:
            for col_off in col_offsets:
                window = Window(col_off, row_off, win_width, win_height)
                data = src.read(window=window)  # Read only the pixels of this window

                # Shift the affine transform so the tile is georeferenced at its own origin.
                meta["transform"] = rasterio.windows.transform(window, src.transform)

                filename = os.path.join(output_dir, f"{stem}_{counter}.tif")
                with rasterio.open(filename, 'w', **meta) as dst:
                    dst.write(data)

                counter += 1

    return counter

def main():
    """Tile the source raster configured below into ``output_dir`` (created if missing)."""
    input_file = 'Path to the source raster.tif'
    output_dir = 'Directory to store output tiles'
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(output_dir, exist_ok=True)

    # Run the tiling process
    process_raster(input_file, output_dir)

# Inside a Jupyter kernel `__name__` is "__main__", so this runs whenever the cell is executed.
if __name__ == "__main__":
    main()

In [None]:
###########################Prediction##########################

# Specify the GPU to use.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow
# initialises its GPU devices. TensorFlow is already imported in the first cell, so make sure
# this runs before any model is loaded/built (or move it above the imports).
os.environ['CUDA_VISIBLE_DEVICES'] = '2'


def write_geotiff(filename, arr, im, arr_type=None):
    """Write a 2-D array as a single-band GeoTIFF, copying georeferencing from ``im``.

    Args:
        filename: Output path.
        arr: 2-D array of shape ``(rows, cols)`` written to band 1.
        im: GDAL dataset whose projection and geotransform are copied.
        arr_type: GDAL data type of the output band; ``None`` means ``gdal.GDT_Int32``
            (predicted labels as 32-bit signed integers).

    Raises:
        ValueError: if ``arr`` is not 2-D.
        OSError: if GDAL cannot create ``filename``.
    """
    if arr.ndim != 2:
        raise ValueError(f"write_geotiff expects a 2-D array, got shape {arr.shape}")
    if arr_type is None:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    rows, cols = arr.shape
    # Create a 1-band raster with the same width/height as the array
    out_im = driver.Create(filename, cols, rows, 1, arr_type)
    if out_im is None:
        raise OSError(f"GDAL could not create {filename!r}")

    try:
        out_im.SetProjection(im.GetProjection())
        out_im.SetGeoTransform(im.GetGeoTransform())

        # Write the array to band 1 and compute basic statistics
        band = out_im.GetRasterBand(1)
        band.WriteArray(arr)
        band.FlushCache()
        band.ComputeStatistics(False)
        out_im.FlushCache()
    finally:
        # Dropping the last references closes the GDAL dataset and finalises the file.
        band = None
        out_im = None

# Define input (tiles) and output (predictions) folders
input_folder = 'link to your own'
output_folder = 'link to your own'

# Predictions reuse each tile's file name, so writing them into the input folder
# would overwrite the source tiles.
if os.path.abspath(input_folder) == os.path.abspath(output_folder):
    raise ValueError("output_folder must differ from input_folder, "
                     "otherwise predictions overwrite the input tiles")
# Outputs go to `output_folder` explicitly; no os.chdir, so later cells keep their cwd.
os.makedirs(output_folder, exist_ok=True)

# Load the trained segmentation model.
# NOTE(review): machine-specific absolute path — point this at your own model file.
MODEL_PATH = 'F:/Topic_three/result/save/S123.h5'
# Divisor applied to raw pixel values; must match the normalisation used in training.
NORMALIZATION_DIVISOR = 12000
model = load_model(MODEL_PATH)

# Iterate over all tiled GeoTIFFs in the input folder (sorted for a reproducible order)
for image_file in sorted(os.listdir(input_folder)):
    if not image_file.endswith(".tif"):
        continue
    image_path = os.path.join(input_folder, image_file)

    # Open tile with GDAL (gdal.Open returns None on failure unless exceptions are enabled)
    image_ds = gdal.Open(image_path, gdal.GA_ReadOnly)
    if image_ds is None:
        print(f'Skipped: {image_file} (could not be opened)')
        continue

    # ReadAsArray gives (bands, rows, cols) for multi-band rasters but (rows, cols)
    # for single-band ones; normalise to 3-D so the transpose below works.
    image_data = image_ds.ReadAsArray()
    if image_data.ndim == 2:
        image_data = image_data[np.newaxis, ...]

    # Skip purely empty tiles.
    if np.all(image_data == 0):
        print(f'Skipped: {image_file} (All values are 0)')
        image_ds = None
        continue

    # Normalization and axis transformation consistent with the training phase:
    # (bands, rows, cols) -> (rows, cols, bands)
    image_data = image_data.astype('float32') / NORMALIZATION_DIVISOR
    image_data = image_data.transpose(1, 2, 0)

    # Model prediction on a batch of one tile (verbose=0: no progress bar per tile)
    predicted_segmentation = model.predict(np.expand_dims(image_data, axis=0), verbose=0)

    # Highest-probability class per pixel, reshaped to this tile's own size
    # rather than a hard-coded 128x128.
    predicted_labels = np.argmax(predicted_segmentation, axis=-1).reshape(
        image_ds.RasterYSize, image_ds.RasterXSize)

    write_geotiff(os.path.join(output_folder, image_file), predicted_labels, image_ds)
    image_ds = None  # Release the GDAL dataset handle

    print(f'Segmented and saved: {image_file}')

In [None]:
###########################Remove overlapping edges##########################

def crop_array_border(data, edge_width):
    """Return ``data`` with ``edge_width`` pixels removed from each side of its last two axes.

    Unlike ``data[..., e:-e, e:-e]``, this also works for ``edge_width == 0``
    (that slice would be empty).

    Raises:
        ValueError: if ``edge_width`` is negative or the crop would leave no pixels.
    """
    if edge_width < 0:
        raise ValueError(f"edge_width must be non-negative, got {edge_width}")
    height, width = data.shape[-2], data.shape[-1]
    if 2 * edge_width >= height or 2 * edge_width >= width:
        raise ValueError(
            f"edge_width={edge_width} is too large for a {height}x{width} raster")
    return data[..., edge_width:height - edge_width, edge_width:width - edge_width]


def crop_tiffs(input_dir, output_dir, edge_width=8):
    """Crop a fixed-width border from every ``*.tif`` in ``input_dir`` into ``output_dir``.

    Each output keeps the input's file name and metadata; size and transform are updated
    so the cropped raster stays correctly georeferenced. If ``output_dir`` equals
    ``input_dir``, the tiles are overwritten in place and re-running crops them again.

    Args:
        input_dir: Directory containing the tiles to crop.
        output_dir: Destination directory (created if missing).
        edge_width: Pixels removed from each of the four sides.

    Raises:
        ValueError: if ``edge_width`` is negative or too large for a tile.
    """
    # Search for all .tif files
    search_files = os.path.join(input_dir, "*.tif")
    file_list = sorted(file for file in glob.glob(search_files) if os.path.isfile(file))

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    for fp in file_list:
        with rasterio.open(fp) as src:
            data = src.read()  # Read all bands

            # Remove a border of `edge_width` pixels
            cropped_data = crop_array_border(data, edge_width)

            # Shift the transform to the new top-left corner of the cropped window
            crop_window = rasterio.windows.Window(
                edge_width, edge_width,
                src.width - 2 * edge_width, src.height - 2 * edge_width)
            out_meta = src.meta.copy()
            out_meta.update({
                "height": cropped_data.shape[1],
                "width": cropped_data.shape[2],
                "transform": rasterio.windows.transform(crop_window, src.transform),
            })

        # Write after the source is closed so in-place cropping does not write to an open file.
        output_path = os.path.join(output_dir, os.path.basename(fp))
        with rasterio.open(output_path, "w", **out_meta) as dest:
            dest.write(cropped_data)

    print(f"All TIFFs have been cropped and saved")

# Crop the overlap border from every predicted tile. With 128-px tiles on a 112-px stride,
# neighbouring tiles overlap by 16 px, i.e. 8 px on each side.
input_directory = 'link to your own'
output_directory = 'link to your own'
crop_tiffs(input_dir=input_directory, output_dir=output_directory, edge_width=8)

In [None]:
###########################Mosaic##########################

def merge_tiffs(input_dir, output_file):
    """Mosaic every ``*.tif`` in ``input_dir`` into one LZW-compressed GeoTIFF.

    The output size and transform come from ``rasterio.merge.merge``; all other metadata
    (CRS, dtype, band count, nodata, ...) is copied from the first input file.

    Args:
        input_dir: Directory containing the tiles to mosaic.
        output_file: Path of the GeoTIFF to write.

    Raises:
        FileNotFoundError: if ``input_dir`` contains no ``.tif`` files.
    """
    # Search for all .tif files
    search_files = os.path.join(input_dir, "*.tif")
    file_list = sorted(file for file in glob.glob(search_files) if os.path.isfile(file))
    if not file_list:
        # merge() cannot handle an empty list; fail with a clear message instead.
        raise FileNotFoundError(f"No .tif files found in {input_dir!r}")

    # Open every tile; the finally-block closes them even if a later open or the merge fails.
    src_files_to_mosaic = []
    try:
        for fp in file_list:
            src_files_to_mosaic.append(rasterio.open(fp))

        # Merge (mosaic) all input rasters into a single array
        mosaic, out_trans = merge(src_files_to_mosaic)
        out_meta = src_files_to_mosaic[0].meta.copy()
    finally:
        for src in src_files_to_mosaic:
            src.close()

    # Prepare output metadata
    out_meta.update({
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans,
        "compress": "lzw"
    })

    # Write the merged array to a GeoTIFF
    with rasterio.open(output_file, "w", **out_meta) as dest:
        dest.write(mosaic)

    print(f"Merged TIFF saved")

# Mosaic the cropped prediction tiles into a single GeoTIFF
input_directory = 'link to your own'
output_tiff = 'link to your own'
merge_tiffs(input_dir=input_directory, output_file=output_tiff)
