# Processing raster calculations (georeference parsing) 

In [8]:
import os
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
from osgeo import gdal
from osgeo import gdal_array
from osgeo import osr

The function below copies the TCI raster's projection and spatial extent (geotransform) onto the final product of our prediction process, so that the product is saved as a georeferenced GeoTIFF.

In [11]:

def georef_parsing(pred_layer, TCI_layer, output_fn, dtype):
    """Save a prediction raster as a GeoTIFF georeferenced like the TCI raster.

    The first three bands of ``pred_layer`` are written to ``output_fn`` and
    the geotransform and projection of ``TCI_layer`` are copied onto it. No
    resampling is done: the prediction is expected to have the same pixel
    dimensions as the TCI, otherwise the copied geotransform (which carries
    the TCI's pixel size) places it with the wrong extent.

    Parameters
    ----------
    pred_layer : str
        Path of the prediction raster (any GDAL-readable format); must have
        at least three bands.
    TCI_layer : str
        Path of the reference raster whose georeferencing is copied.
    output_fn : str
        Path of the GeoTIFF to create.
    dtype : str
        GDAL data type name for the output, e.g. ``"Byte"`` or ``"Float32"``.

    Raises
    ------
    ValueError
        If ``dtype`` is not a GDAL data type name, or the prediction raster
        has fewer than three bands.
    RuntimeError
        If an input cannot be opened or the output cannot be created.
    """
    import warnings

    # Resolve the output type by its GDAL name ("Byte", "Float32", ...).
    # Fail early: the GTiff driver cannot create a GDT_Unknown dataset.
    gdt_dtype = gdal.GetDataTypeByName(dtype)
    if gdt_dtype == gdal.GDT_Unknown:
        raise ValueError(f"Not supported data type: {dtype!r}")

    pred_ds = gdal.Open(pred_layer)
    if pred_ds is None:
        raise RuntimeError(f"Cannot open prediction raster: {pred_layer}")
    tci_ds = gdal.Open(TCI_layer)
    if tci_ds is None:
        raise RuntimeError(f"Cannot open TCI raster: {TCI_layer}")

    band_num = 3
    if pred_ds.RasterCount < band_num:
        raise ValueError(
            f"{pred_layer} has {pred_ds.RasterCount} band(s); "
            f"{band_num} are required"
        )

    # GDAL sizes are (x = columns, y = rows); NumPy arrays are (rows, cols).
    cols, rows = pred_ds.RasterXSize, pred_ds.RasterYSize
    if (cols, rows) != (tci_ds.RasterXSize, tci_ds.RasterYSize):
        warnings.warn(
            f"Prediction size {cols}x{rows} differs from TCI size "
            f"{tci_ds.RasterXSize}x{tci_ds.RasterYSize}; the copied "
            "geotransform will not match the prediction's extent."
        )

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(output_fn, cols, rows, band_num, gdt_dtype)
    if out_ds is None:
        raise RuntimeError(f"Cannot create output raster: {output_fn}")

    # Copy the full geotransform (rotation terms included; they are 0 for
    # north-up rasters) and the projection WKT from the TCI.
    out_ds.SetGeoTransform(tci_ds.GetGeoTransform())
    out_ds.SetProjection(tci_ds.GetProjection())

    # Copy band by band; WriteArray converts to the output band's data type.
    for band_idx in range(1, band_num + 1):
        data = pred_ds.GetRasterBand(band_idx).ReadAsArray()
        out_ds.GetRasterBand(band_idx).WriteArray(data)

    # Flush every band, then release the datasets: GDAL's Python bindings
    # close (and finalise) the file when the dataset object is dereferenced.
    out_ds.FlushCache()
    out_ds = None
    pred_ds = None
    tci_ds = None
    


In [13]:
# Single-file run: georeference the concatenated prediction using the TCI raster.
tci_tiffben = 'some_path.../sentinel/T34TGS_20220513T092031_TCI_10m_tiffben.tif'
pred_prod = 'some_path.../sentinel_concat.jpg'

georef_parsing(
    pred_layer=pred_prod,
    TCI_layer=tci_tiffben,
    output_fn='sentinel_concat_georeftry.tif',
    dtype="Byte",
)

In [None]:
# For multiple files.
# NOTE(review): this cell was copied from another project and still needs
# editing: `masking_process` is not defined in this notebook (its call has the
# same argument layout as `georef_parsing` above; confirm which is intended),
# and the paths/glob patterns are placeholders.

# glob() returns files in arbitrary order; sort both lists so the positional
# pairing below is deterministic. This assumes mask and TCI file names sort
# into the same order. Verify against the actual naming scheme.
all_mask = sorted(glob('some_path.../' + '*binary.tif'))
all_TCI = sorted(glob('some_path.../' + '*.jp2'))

if len(all_mask) != len(all_TCI):
    raise ValueError(
        f"Found {len(all_mask)} mask file(s) but {len(all_TCI)} TCI file(s); "
        "they are paired one-to-one."
    )

# Output names use the first six characters of each TCI file name.
names = [os.path.basename(x)[:6] for x in all_TCI]

for mask_fn, tci_fn, name in zip(all_mask, all_TCI, names):
    outname = f'TCI_masked_{name}.tif'
    masking_process(mask_fn, tci_fn, outname, "Byte")

print('Done')