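This notebook takes per-tile binary classification predictions (channel vs. no channel) and stitches them into a single basin-wide prediction raster saved as a GeoTIFF. Each tile is filled with either its predicted class or, when `probability_mode` is enabled, its predicted probability of containing a channel.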

In [None]:
from osgeo import gdal
from osgeo import gdalconst
import glob
import numpy as np
import os
import pandas as pd

# Basin identifier; it becomes the prefix of the merged output raster's file name.
basin_name: str = 'Af_126767'
# False -> each tile is filled with its predicted class (0 or 1).
# True  -> each tile is filled with its predicted channel probability.
probability_mode: bool = False

In [None]:
# Folders holding the input tiles (globbed for *.tif). They are processed in
# this order: No_Channel first, then Channel.
folder_paths = ['No_Channel/', 'Channel/']

# Per-tile rasters filled with the prediction, and their resampled copies.
output_folder_name = "predictedTiles/"
output_folder_name_resized = "predictedTilesResized/"
filename_ending = "predictedTile.tif"

# Divisor used to derive the warp resolution (cols / new_size, rows / new_size).
# NOTE(review): 64 / 0.7 (~91.43) is a magic number -- presumably tied to a
# 64-px tile size and a 0.7 factor; TODO confirm and document its origin.
new_size = 64 / 0.7

# Per-tile prediction table.
# Alternative input: "resnet_50_prc_basinwise_test_predictions.csv".
predictions_csv = "resnet_50_prc_tilewise_all_predictions.csv"
df = pd.read_csv(predictions_csv)

# Fail fast (rather than mid-way through the tile loop) if the columns that
# the tile lookup reads for the current mode are absent.
required_columns = {'filename', 'probability' if probability_mode else 'predictions'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(
        f"{predictions_csv} is missing required column(s): {sorted(missing_columns)}"
    )

Fill each tile with either its predicted class or its predicted probability of containing a channel, then resample it to the target resolution.

In [None]:
missingCount = 0
notMissingCount = 0


def lookup_tile_prediction(predictions_df, file_core, use_probability):
    """Return the prediction recorded for one tile, or None if no row matches.

    A row matches when its 'filename' contains ``file_core`` as a plain
    substring (regex disabled; NaN filenames never match).

    Parameters
    ----------
    predictions_df : pandas.DataFrame
        Prediction table with a 'filename' column and a 'probability' and/or
        'predictions' column.
    file_core : str
        Tile file name with the extension removed but the trailing '.' kept,
        which stops e.g. 'tile_1.' from matching 'tile_10.tif'.
    use_probability : bool
        True -> return the 'probability' value as a float;
        False -> return the 'predictions' value as an int.

    Returns
    -------
    float, int or None
        The tile's prediction, or None when the tile has no row.

    Raises
    ------
    ValueError
        If more than one row matches, since the tile's value would be ambiguous.
    """
    matches = predictions_df[
        predictions_df['filename'].str.contains(file_core, regex=False, na=False)
    ]
    if matches.empty:
        return None
    if len(matches) > 1:
        raise ValueError(
            f"{len(matches)} prediction rows match tile '{file_core}'; expected exactly one"
        )
    if use_probability:
        return float(matches['probability'].iloc[0])
    return int(matches['predictions'].iloc[0])


# Create the output folders if they don't exist.
# NOTE(review): files left over from earlier runs are not removed, and the
# merge step picks up every .tif in the resized folder.
os.makedirs(output_folder_name, exist_ok=True)
os.makedirs(output_folder_name_resized, exist_ok=True)

driver = gdal.GetDriverByName("GTiff")

# Running index shared by ALL folders. Numbering per folder made the second
# folder's tiles overwrite the first folder's outputs with the same number.
tile_index = 0

for folder_path in folder_paths:
    # Sorted so the tile -> output-number mapping is reproducible across runs.
    tiff_files = sorted(glob.glob(folder_path + '*.tif'))

    for input_file in tiff_files:
        file_core = os.path.splitext(os.path.basename(input_file))[0] + '.'
        predicted_label = lookup_tile_prediction(df, file_core, probability_mode)
        if predicted_label is None:
            # No prediction for this tile: skip it before any output file is created.
            missingCount += 1
            continue
        if not probability_mode and predicted_label not in (0, 1):
            raise ValueError(
                f"Unexpected class {predicted_label} for tile {input_file}; expected 0 or 1"
            )

        input_ds = gdal.Open(input_file, gdal.GA_ReadOnly)
        if input_ds is None:
            raise RuntimeError(f"GDAL could not open {input_file}")
        cols = input_ds.RasterXSize
        rows = input_ds.RasterYSize

        # Single-band Float32 raster where every pixel holds the tile's
        # prediction (class 0/1 or channel probability), georeferenced like the input.
        filename = output_folder_name + str(tile_index) + filename_ending
        output_ds = driver.Create(filename, cols, rows, 1, gdal.GDT_Float32)
        output_ds.GetRasterBand(1).WriteArray(
            np.full((rows, cols), predicted_label, dtype=np.float32)
        )
        output_ds.SetGeoTransform(input_ds.GetGeoTransform())
        output_ds.SetProjection(input_ds.GetProjection())

        # Dropping the references closes the datasets, flushing the output to
        # disk before it is read back by gdal.Warp below.
        input_ds = None
        output_ds = None
        notMissingCount += 1

        # NOTE(review): gdal.Warp takes xRes/yRes in georeferenced units, but
        # cols / new_size is a pixel-count ratio; they agree only if the input
        # pixel size is 1 unit -- confirm against the tiles' geotransform.
        xres = cols / new_size
        yres = rows / new_size
        output_filename = output_folder_name_resized + str(tile_index) + filename_ending
        gdal.Warp(output_filename, filename, xRes=xres, yRes=yres, resampleAlg=gdal.GRA_Max)
        tile_index += 1

print(missingCount)
print(notMissingCount)

2
2158


Merge the resized tiles into the final basin-wide prediction raster.

In [None]:
input_folder = output_folder_name_resized
output_raster = basin_name + '_predictions.tif'

# List the resized tiles. Sorting matters: where VRT sources overlap, the
# source listed last supplies the pixels, so relying on os.listdir's arbitrary
# order would make the merged output nondeterministic.
raster_files = sorted(
    os.path.join(input_folder, f) for f in os.listdir(input_folder) if f.endswith('.tif')
)
if not raster_files:
    raise FileNotFoundError(f"No .tif tiles found in {input_folder}")

# Mosaic the tiles through a temporary virtual raster.
# NOTE(review): no nodata value is configured, so areas covered by no tile
# likely read as 0 -- indistinguishable from class 0 / probability 0. Consider
# setting srcNodata/VRTNodata if that matters downstream.
vrt_path = 'temp.vrt'
vrt_options = gdal.BuildVRTOptions(resolution='average', addAlpha=False)
vrt_ds = gdal.BuildVRT(vrt_path, raster_files, options=vrt_options)
if vrt_ds is None:
    raise RuntimeError("gdal.BuildVRT failed to build the tile mosaic")

try:
    # gdal.Translate copies the mosaic's size, band count/type, geotransform,
    # projection and pixels block by block, instead of loading each full band
    # into memory with ReadAsArray.
    output_ds = gdal.Translate(output_raster, vrt_ds, format='GTiff')
    if output_ds is None:
        raise RuntimeError(f"gdal.Translate failed to write {output_raster}")
finally:
    # Close both datasets (flushing the GeoTIFF) before deleting the VRT,
    # and always clean up the temporary file even if translation failed.
    output_ds = None
    vrt_ds = None
    if os.path.exists(vrt_path):
        os.remove(vrt_path)