In [1]:
import sys
dataFolder = '../../../../../data/fields/'
sys.path.append('../')
from Auxiliary.helper import *

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.


In [2]:
# Make the shared data directory the working directory; every relative path
# used from here on is resolved against it.
# NOTE(review): os.chdir is a process-wide side effect and this cell is not
# idempotent -- running it a second time resolves '../../../../../data/'
# against the *new* working directory. The relative `dataFolder` defined in
# the first cell is also re-interpreted after this call -- confirm it still
# points at the intended 'fields' directory.
os.chdir('../../../../../data/')
# Folder that is searched for '.vrt' files (relative to the directory entered above).
img_folder_path = 'fields/Auxiliary/vrt/Force_X_from_64_to_73_Y_from_39_to_47/'
## define function to load vrt
def loadVRTintoNumpyAI4(vrtPath):
    """Read the single-band VRTs of a folder into one normalised 4-D cube.

    All '.vrt' files returned by ``getFilelist(vrtPath, '.vrt')`` whose path does
    not contain 'Cube' are ordered by the integer that follows the last
    underscore of their name, band 1 of each file is read, and the resulting
    24 layers are arranged as an array of shape (4, 6, rows, cols) that is
    passed through ``AI4BNormal_S2``.

    Parameters
    ----------
    vrtPath : str
        Folder handed to ``getFilelist``.

    Returns
    -------
    Whatever ``AI4BNormal_S2()(cube)`` returns for the (4, 6, rows, cols) cube.

    Raises
    ------
    ValueError
        If no matching VRT is found, the rasters differ in size, or the number
        of layers is not 4 * 6.
    RuntimeError
        If GDAL cannot open one of the files.
    """
    vrtFiles = [file for file in getFilelist(vrtPath, '.vrt') if 'Cube' not in file]
    if not vrtFiles:
        raise ValueError(f"no '.vrt' files (other than 'Cube') found in {vrtPath!r}")

    # Sort key: trailing integer of the file name ('..._<n>.vrt').
    # presumably the last element returned by sortListwithOtherlist is the
    # re-ordered file list -- verify against the helper module.
    sortKeys = [int(vrt.split('_')[-1].split('.')[0]) for vrt in vrtFiles]
    vrtFiles = sortListwithOtherlist(sortKeys, vrtFiles)[-1]

    bands = []
    for vrt in vrtFiles:
        ds = gdal.Open(vrt)
        # Without gdal.UseExceptions() a failed open returns None instead of raising.
        if ds is None:
            raise RuntimeError(f'GDAL could not open {vrt!r}')
        bands.append(ds.GetRasterBand(1).ReadAsArray())
        ds = None  # release the dataset handle as soon as the band is in memory

    if len(bands) != 4 * 6:
        raise ValueError(f'expected {4 * 6} VRT layers for a (4, 6, rows, cols) cube, found {len(bands)}')

    # Stacking along a new first axis yields (layer, row, col) directly, i.e. the
    # same values as np.dstack followed by a (2, 0, 1) transpose.
    data_cube = np.stack(bands, axis=0)
    nRows, nCols = bands[0].shape
    reshaped_cube = data_cube.reshape(4, 6, nRows, nCols)
    normalizer = AI4BNormal_S2()
    return normalizer(reshaped_cube)
    
    # ds = gdal.Open(vrtPath)
    # bandNumber = ds.RasterCount
    # bands = []
    # for i in range(bandNumber):
    #     bands.append(ds.GetRasterBand(i+1).ReadAsArray())
    # cube = np.dstack(bands)
    # data_cube = np.transpose(cube, (2, 0, 1))
    # reshaped_cube = data_cube.reshape(4, 6, ds.RasterYSize, ds.RasterXSize)
    # return reshaped_cube

def getGeoTFandProj(vrtPath):
    """Return the geotransform and the projection of a raster file.

    Parameters
    ----------
    vrtPath : str
        Path of a raster that GDAL can open.

    Returns
    -------
    tuple
        ``(geotransform, projection)`` as returned by the dataset's
        ``GetGeoTransform()`` and ``GetProjection()``.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``vrtPath``.
    """
    ds = gdal.Open(vrtPath)
    # Without gdal.UseExceptions() a failed open returns None instead of raising.
    if ds is None:
        raise RuntimeError(f'GDAL could not open {vrtPath!r}')
    return ds.GetGeoTransform(), ds.GetProjection()

# Load the FORCE VRT layers of `img_folder_path` into the normalised cube `dat`.
# The later cells read its spatial size from dat.shape[2:] and slice chips from it.
dat = loadVRTintoNumpyAI4(img_folder_path)


KeyboardInterrupt: 

In [None]:
chipsize = 128*1 # 5 is the maximum with GPU in basement
overlap  = 20
rows, cols = dat.shape[2:]


def _chipBounds(length, size, stride):
    """Return (starts, ends) of every full window of `size` pixels placed every `stride` pixels.

    A window ending exactly on `length` is included (hence `length + 1` as
    the exclusive stop of the range); pixels that do not fill a whole window
    at the far edge are not covered.
    """
    ends = list(range(size, length + 1, stride))
    starts = [end - size for end in ends]
    return starts, ends


# Distance between the origins of neighbouring chips.
chip_stride = chipsize - overlap
if chip_stride <= 0:
    raise ValueError(f'overlap ({overlap}) must be smaller than chipsize ({chipsize})')

row_start, row_end = _chipBounds(rows, chipsize, chip_stride)
col_start, col_end = _chipBounds(cols, chipsize, chip_stride)

In [None]:
# define the model (.pth) and run it chip by chip over `dat`
model_name = dataFolder + 'output/models/model_state_All_but_LU_transformed_42.pth'
model_name_short = model_name.split('/')[-1].split('.')[0]
local_rank = 0
# torch.cuda.set_device(local_rank)
# torch.manual_seed(0)

NClasses = 1
nf = 96
verbose = True
model_config = {'in_channels': 4,
                'spatial_size_init': (128, 128),
                'depths': [2, 2, 5, 2],
                'nfilters_init': nf,
                'nheads_start': nf // 4,
                'NClasses': NClasses,
                'verbose': verbose,
                'segm_act': 'sigmoid'}

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build and load the model on whichever device is available. map_location lets
# a checkpoint that was saved from a GPU be restored on a CPU-only machine.
model = ptavit3d_dn(**model_config)
model.load_state_dict(torch.load(model_name, map_location=device))
model = model.to(device)
model.eval()

preds = []
image = None  # defined up front so the clean-up below also works when there are no chips

with torch.no_grad():
    # Row-major order (rows outer, columns inner).
    for i in range(len(row_end)):
        for j in range(len(col_end)):
            # Add a leading batch axis and cut the chip window out of the cube.
            chip = dat[np.newaxis, :, :, row_start[i]:row_end[i], col_start[j]:col_end[j]]
            image = torch.as_tensor(chip, dtype=torch.float, device=device)
            pred = model(image)
            preds.append(pred.detach().cpu().numpy())

# Drop the references first, so that empty_cache() can actually hand the
# model's and the last chip's memory back to the GPU driver.
del model
del device
del image
torch.cuda.empty_cache()

In [None]:
outFolder = dataFolder + 'output/predictions/FORCE/BRANDENBURG/'
gtiff_driver = gdal.GetDriverByName('GTiff')
vrts = getFilelist(img_folder_path, '.vrt')
geoTF, geoPr = getGeoTFandProj(vrts[0])

# (row, col) pixel offset of every chip, rows outer and columns inner -- the
# same order in which the inference loop appended to `preds`.
chip_offsets = [(row_off, col_off) for row_off in row_start for col_off in col_start]
filenames = [f'X_{col_off}_Y_{row_off}.tif' for row_off, col_off in chip_offsets]
if len(preds) != len(chip_offsets):
    # Guards against stale notebook state (grid changed after inference was run).
    raise ValueError(f'{len(preds)} predictions do not match {len(chip_offsets)} chip positions')

# load mask
# NOTE(review): the mask is indexed with the chip offsets below, which assumes
# it shares origin and pixel size with the image cube -- confirm.
ds = gdal.Open(dataFolder + 'IACS/Auxiliary/GSA-DE_BRB-2019_All_agromask.tif')
mask = ds.GetRasterBand(1).ReadAsArray()

# Margin trimmed from every side of a chip, and the size of what remains.
half = overlap // 2
inner = chipsize - 2 * half

for (row_off, col_off), file, pred in zip(chip_offsets, filenames, preds):
    # (bands, rows, cols) -> (rows, cols, bands), then trim the margin.
    # Explicit end indices are used because a slice ending at -0 would be empty.
    arr = pred[0].transpose(1, 2, 0)
    core = arr[half:arr.shape[0] - half, half:arr.shape[1] - half, :]

    # change the Geotransform for each chip: move the origin to the first kept pixel
    geotf = list(geoTF)
    geotf[0] = geotf[0] + geotf[1] * (col_off + half)
    geotf[3] = geotf[3] + geotf[5] * (row_off + half)

    # Mask window that covers exactly the kept part of the chip.
    maskSub = mask[row_off + half:row_off + chipsize - half,
                   col_off + half:col_off + chipsize - half]

    for subfolder in ['chips/', 'masked_chips/']:
        out_path = f'{outFolder}{subfolder}{str(chipsize)}_{overlap}_{file}'
        out_ds = gtiff_driver.Create(out_path, inner, inner, 3, gdal.GDT_Float32)
        if out_ds is None:
            raise RuntimeError(f'GDAL could not create {out_path!r} (does the folder exist?)')
        out_ds.SetGeoTransform(tuple(geotf))
        out_ds.SetProjection(geoPr)

        for band in range(3):
            layer = core[:, :, band]
            if subfolder == 'masked_chips/':
                layer = layer * maskSub
            out_ds.GetRasterBand(band + 1).WriteArray(layer)
        # Dropping the last reference flushes and closes the file.
        del out_ds


In [None]:
def getExtentRas(raster):
    """Return the bounding box of a raster in its georeferenced coordinates.

    Parameters
    ----------
    raster : str, os.PathLike or dataset
        Either a path that GDAL can open, or an already opened dataset (any
        object offering ``GetGeoTransform()``, ``RasterXSize`` and
        ``RasterYSize``, such as ``gdal.Dataset``).

    Returns
    -------
    dict
        Keys 'Xmin', 'Xmax', 'Ymin', 'Ymax'. The values are derived from the
        origin and pixel-size terms of the geotransform only, so the result
        is meaningful for north-up rasters without rotation terms.

    Raises
    ------
    TypeError
        If ``raster`` is neither a path nor a dataset-like object.
    RuntimeError
        If GDAL cannot open the given path.
    """
    import os  # local import keeps the function usable on its own

    if isinstance(raster, (str, os.PathLike)):
        ds = gdal.Open(os.fspath(raster))
        # Without gdal.UseExceptions() a failed open returns None instead of raising.
        if ds is None:
            raise RuntimeError(f'GDAL could not open {raster!r}')
    elif hasattr(raster, 'GetGeoTransform'):
        ds = raster
    else:
        raise TypeError(f'raster must be a path or a GDAL dataset, got {type(raster).__name__}')

    gt = ds.GetGeoTransform()
    ext = {'Xmin': gt[0],
           'Xmax': gt[0] + (gt[1] * ds.RasterXSize),
           'Ymin': gt[3] + (gt[5] * ds.RasterYSize),
           'Ymax': gt[3]}
    return ext

def commonBoundsDim(extentList):
    """Return the bounding box shared by all given extents (their intersection).

    Parameters
    ----------
    extentList : iterable of dict
        Extents with the keys 'Xmin', 'Xmax', 'Ymin' and 'Ymax'.

    Returns
    -------
    dict
        Same four keys: the largest of the minima and the smallest of the
        maxima over all extents.
    """
    # Materialise once so that a one-shot iterator can be scanned per key.
    extents = list(extentList)
    # The common box starts at the largest minimum and stops at the smallest maximum.
    reducers = (('Xmin', max), ('Xmax', min), ('Ymin', max), ('Ymax', min))
    return {key: pick([extent[key] for extent in extents]) for key, pick in reducers}

def commonBoundsCoord(ext):
    """Turn extent dictionaries into dictionaries of corner coordinates.

    Parameters
    ----------
    ext : dict or iterable of dict
        One extent, or several, with the keys 'Xmin', 'Xmax', 'Ymin', 'Ymax'.
        A single plain dict is treated as a one-element list.

    Returns
    -------
    list of dict
        One entry per extent with the keys 'UpperLeftXY', 'UpperRightXY',
        'LowerRightXY' and 'LowerLeftXY', each holding an ``[x, y]`` list.
    """
    extents = [ext] if type(ext) is dict else ext
    return [
        {'UpperLeftXY': [box['Xmin'], box['Ymax']],
         'UpperRightXY': [box['Xmax'], box['Ymax']],
         'LowerRightXY': [box['Xmax'], box['Ymin']],
         'LowerLeftXY': [box['Xmin'], box['Ymin']]}
        for box in extents
    ]

In [33]:
# check if mask has different extent from prediction
# if so, make it the same extent for further processing (classification)
# --> mask can never be smaller than prediction, therefore no need to check
mask_path = 'IACS/Auxiliary/GSA-DE_BRB-2019_All_agromask_linecrop.tif'
ext_mask = getExtentRas(dataFolder + mask_path)
ext_pred = getExtentRas(dataFolder + 'output/predictions/FORCE/BRANDENBURG/vrt/256_20_masked_chipsvrt.vrt')

common_bounds = commonBoundsDim([ext_mask, ext_pred])
common_coords = commonBoundsCoord(common_bounds)
# The intersection equals the prediction extent exactly when the prediction
# lies completely inside the mask; only then is the mask cropped.
if common_bounds == ext_pred:
    ds = gdal.Open(dataFolder + mask_path)
    in_gt = ds.GetGeoTransform()
    inv_gt = gdal.InvGeoTransform(in_gt)
    # transform coordinates into offsets (in cells) and make them integer
    off_UpperLeft = gdal.ApplyGeoTransform(inv_gt, common_coords[0]['UpperLeftXY'][0], common_coords[0]['UpperLeftXY'][1])
    off_LowerRight = gdal.ApplyGeoTransform(inv_gt, common_coords[0]['LowerRightXY'][0], common_coords[0]['LowerRightXY'][1])
    off_ULx, off_ULy = map(round, off_UpperLeft)
    off_LRx, off_LRy = map(round, off_LowerRight)
    # Size of the cropped window in cells.
    n_cols = off_LRx - off_ULx
    n_rows = off_LRy - off_ULy

    band = ds.GetRasterBand(1)
    data = band.ReadAsArray(off_ULx, off_ULy, n_cols, n_rows)

    out_path = dataFolder + mask_path.split('.')[0] + '_prediction_extent.tif'
    out_ds = gdal.GetDriverByName('GTiff').Create(out_path, n_cols, n_rows, 1, band.DataType)
    if out_ds is None:
        raise RuntimeError(f'GDAL could not create {out_path!r}')
    # Same pixel size as the input; only the origin moves to the new upper-left cell.
    out_gt = list(in_gt)
    out_gt[0], out_gt[3] = gdal.ApplyGeoTransform(in_gt, off_ULx, off_ULy)
    out_ds.SetGeoTransform(out_gt)
    out_ds.SetProjection(ds.GetProjection())

    out_ds.GetRasterBand(1).WriteArray(data)
    # Compare against None: a nodata value of 0 is valid but falsy.
    nodata = band.GetNoDataValue()
    if nodata is not None:
        out_ds.GetRasterBand(1).SetNoDataValue(nodata)
    # Dropping the reference flushes and closes the file.
    del out_ds
else:
    print('Prediction extent is not fully covered by the mask; no cropped mask was written.')
    