In [None]:
"""Conceptual approach

for every (original image patch, classified patch) pair
    read the georeference information of the patch centre from the original patch   retrieve_geoinfo()
    create a new raster in the temporary directory carrying that georeference         out_dataset_creation()
    write only the central portion of the classified patch into it, so that
    overlapping neighbouring patches do not duplicate pixels                          out_band_write()

merge the cropped patches into one raster                                            merge_rasters()
"""

In [None]:
# imports
from osgeo import gdal
import os
import shutil

# Report the GDAL build in use, make GDAL raise Python exceptions instead of
# returning None / error codes, and fetch the GeoTIFF driver used for all outputs.
print(gdal.__version__)
gdal.UseExceptions()
out_driver = gdal.GetDriverByName('GTiff')
if out_driver is None:
    # GetDriverByName returns None for an unregistered driver; fail here rather
    # than with an AttributeError later in out_dataset_creation().
    raise RuntimeError('GDAL GeoTIFF driver (GTiff) is not available')

In [None]:
# Defining paths
root_path = 'E:\\datasets\\test_unet\\Krkonose2012\\overlap'
imagery_path = os.path.join(root_path, 'MHS')
results_path = os.path.join(root_path, 'Results')
final_filepath = os.path.join(root_path, 'KrakonosNet_Classified.tif')

# Scratch directory for the cropped, georeferenced patches (removed after merging).
# It must start empty: merge_rasters() mosaics every file it finds there, so
# leftovers from an interrupted earlier run would leak into the final raster.
temp_path = os.path.join(results_path, 'temp_geocoding')
try:
    os.mkdir(temp_path)
except FileExistsError:
    # Previously this only printed a warning and carried on with stale files.
    raise FileExistsError(
        f'Path to temporary files already exists, delete the following directory first:\n{temp_path}'
    ) from None

In [None]:
def retrieve_geoinfo(img_path):
    """Read projection and georeferencing for the central crop of an image patch.

    The crop is half the patch width and height, and its top-left pixel sits at
    (``XSize // 2``, ``YSize // 2``) of the patch. This is the same window that
    ``out_band_write`` extracts from the classified patch.

    Parameters
    ----------
    img_path : str
        Path to the original, georeferenced image patch.

    Returns
    -------
    dict
        ``'projection'``: string returned by ``GetProjection()``.
        ``'YSize'`` / ``'XSize'``: crop height / width in pixels.
        ``'geotransform'``: 6-tuple geotransform of the crop's top-left corner.
    """
    img = gdal.Open(img_path)
    try:
        projection = img.GetProjection()
        x_size = img.RasterXSize // 2
        y_size = img.RasterYSize // 2
        gt = img.GetGeoTransform()
    finally:
        img = None  # release the GDAL dataset handle even if a call above raised

    # Integer pixel offset of the crop. The previous RasterXSize / 4 was a float and
    # drifted half a pixel from the integer crop whenever the size was not a multiple of 4.
    col_off = x_size // 2
    row_off = y_size // 2

    # Affine geotransform: X = gt0 + col*gt1 + row*gt2 ; Y = gt3 + col*gt4 + row*gt5.
    # Including gt2/gt4 keeps rotated rasters correct; for north-up rasters they are 0.
    new_geotransform = (gt[0] + col_off * gt[1] + row_off * gt[2],
                        gt[1],
                        gt[2],
                        gt[3] + col_off * gt[4] + row_off * gt[5],
                        gt[4],
                        gt[5])

    return {
        'projection': projection,
        'YSize': y_size,
        'XSize': x_size,
        'geotransform': new_geotransform,
    }

In [None]:
def out_dataset_creation(temp_dir, result_filename, geoinfo, driver):
    """Create a single-band Byte raster for one cropped, georeferenced patch.

    Parameters
    ----------
    temp_dir : str
        Directory the new raster is written into.
    result_filename : str
        File name of the new raster (joined onto ``temp_dir``).
    geoinfo : dict
        As returned by ``retrieve_geoinfo``: ``'XSize'`` / ``'YSize'`` (pixel
        width / height), ``'geotransform'`` and ``'projection'``.
    driver : gdal.Driver
        GDAL driver used to create the file.

    Returns
    -------
    gdal.Dataset
        The open output dataset. The caller writes the band and then drops the
        reference so GDAL flushes it to disk.
    """
    out_filename = os.path.join(temp_dir, result_filename)
    # Driver.Create takes (path, xsize, ysize, bands, type): width (columns) comes first.
    # The arguments used to be swapped, which transposed non-square patches.
    out_raster = driver.Create(out_filename, geoinfo['XSize'], geoinfo['YSize'], 1, gdal.GDT_Byte)
    out_raster.SetGeoTransform(geoinfo['geotransform'])
    out_raster.SetProjection(geoinfo['projection'])

    return out_raster

In [None]:
def out_band_write(out_dataset, classification, geoinfo, nodata=255):
    """Copy the central window of a classified patch into ``out_dataset``.

    Neighbouring patches overlap, so only the middle of each patch is kept.
    The kept window is ``geoinfo['XSize']`` x ``geoinfo['YSize']`` pixels, and its
    top-left pixel is (``XSize // 2``, ``YSize // 2``) of the classified patch.

    Parameters
    ----------
    out_dataset : gdal.Dataset
        Open output raster whose band 1 receives the window.
    classification : str
        Path to the classified patch. It is presumably the same size as the original
        image patch the geoinfo was read from (TODO confirm).
    geoinfo : dict
        Must contain integer ``'XSize'`` and ``'YSize'`` (window width / height).
    nodata : int, optional
        NoData value set on the output band (default 255).
    """
    x_off = geoinfo['XSize'] // 2
    y_off = geoinfo['YSize'] // 2

    classified = gdal.Open(classification)
    try:
        # Windowed read (xoff, yoff, win_xsize, win_ysize) returns an array shaped
        # (rows, cols) = (YSize, XSize). The old slice indexed [x, y] on a (row, col)
        # array and so took the wrong window for non-square patches.
        out_arr = classified.GetRasterBand(1).ReadAsArray(
            x_off, y_off, geoinfo['XSize'], geoinfo['YSize'])
    finally:
        classified = None  # release the source dataset

    out_band = out_dataset.GetRasterBand(1)
    out_band.WriteArray(out_arr)
    out_band.SetNoDataValue(nodata)

In [None]:
def merge_rasters(rasters_dir, out_filepath, out_src):
    """Mosaic every raster file in ``rasters_dir`` into one output raster.

    A VRT (``temp_merged.vrt``) is built inside ``rasters_dir`` over the sorted
    list of files. It is given the projection ``out_src`` and then written to
    ``out_filepath`` with ``gdal.Translate``.

    Parameters
    ----------
    rasters_dir : str
        Directory holding the cropped patches. Only regular files are used;
        sub-directories and a leftover ``temp_merged.vrt`` are skipped.
    out_filepath : str
        Destination path of the merged raster.
    out_src : str
        Projection string applied to the mosaic via ``SetProjection``.

    Raises
    ------
    ValueError
        If ``rasters_dir`` contains no files to merge.
    """
    vrt_name = 'temp_merged.vrt'
    vrt_path = os.path.join(rasters_dir, vrt_name)

    # Sorted so the mosaic is deterministic: where VRT sources overlap, files later
    # in the list take precedence. The VRT itself must never become its own source.
    raster_paths = [
        os.path.join(rasters_dir, name)
        for name in sorted(os.listdir(rasters_dir))
        if name != vrt_name and os.path.isfile(os.path.join(rasters_dir, name))
    ]
    if not raster_paths:
        raise ValueError(f'No rasters to merge in {rasters_dir}')

    vrt = gdal.BuildVRT(vrt_path, raster_paths)
    try:
        vrt.SetProjection(out_src)
        merged = gdal.Translate(out_filepath, vrt)
        merged = None  # dereference the output dataset so it is flushed and closed
    finally:
        vrt = None

In [None]:
def _sorted_files(directory):
    """Return the sorted names of the regular files directly inside ``directory``."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


try:
    # os.listdir order is arbitrary, so both listings are sorted before pairing.
    # The isfile filter is needed because temp_path lives inside results_path:
    # the raw listing of results_path also contains the 'temp_geocoding' directory.
    # NOTE(review): pairing assumes image and result names sort into the same
    # order; confirm against the patch naming scheme.
    image_names = _sorted_files(imagery_path)
    result_names = _sorted_files(results_path)
    if not image_names:
        raise ValueError(f'No image patches found in {imagery_path}')
    if len(image_names) != len(result_names):
        # zip() would silently drop the unmatched files; fail loudly instead.
        raise ValueError(
            f'{len(image_names)} image patches but {len(result_names)} classified patches'
        )

    for image_name, result_name in zip(image_names, result_names):
        ginfo = retrieve_geoinfo(os.path.join(imagery_path, image_name))
        out_ds = out_dataset_creation(temp_path, result_name, ginfo, out_driver)
        out_band_write(out_ds, os.path.join(results_path, result_name), ginfo)
        out_ds = None  # dereference so GDAL flushes and closes the patch

    # Assumes every patch shares one projection; the last patch's is used for the mosaic.
    merge_rasters(temp_path, final_filepath, ginfo['projection'])
finally:
    # Always remove the scratch directory so a re-run starts from an empty one.
    shutil.rmtree(temp_path, ignore_errors=True)