In [None]:
import os
import sys

dataFolder = '../../../../data/fields/'
sys.path.append(dataFolder)
from Auxiliary.helper import *

In [None]:
# `dataFolder` is relative to the notebook's starting directory. Make it absolute
# *before* changing directory: otherwise every path built from it afterwards
# (model weights, prediction output) resolves against the new working directory
# and points to the wrong place.
dataFolder = os.path.join(os.path.abspath(dataFolder), '')
# Work from the `data/` directory (the parent of `fields/`). Deriving the target
# from the absolute path, instead of chdir-ing by a relative path, keeps this
# cell idempotent when it is re-run.
os.chdir(os.path.dirname(os.path.dirname(dataFolder)))
# Input raster stack; relative to the new working directory (`data/`).
img_lab_path = 'fields/Auxiliary/vrt/Force_X_from_68_to_69_Y_from_42_to_42_Cube.vrt'
## define function to load vrt
def loadVRTintoNumpyAI4(vrtPath, n_channels=4, n_times=6):
    """Read every band of a GDAL raster (e.g. a VRT) into a 4-D numpy cube.

    The raster's ``n_channels * n_times`` bands are stacked in band order and
    reshaped to ``(n_channels, n_times, rows, cols)``. Band ``k`` (0-based)
    therefore lands at ``[k // n_times, k % n_times]``. The defaults reproduce
    the previous fixed 4 x 6 layout. Rows/cols are read from the raster
    instead of being hard-coded.

    Args:
        vrtPath: path of the raster to open with ``gdal.Open``.
        n_channels: size of the first output axis (default 4).
        n_times: size of the second output axis (default 6).

    Returns:
        numpy array of shape ``(n_channels, n_times, RasterYSize, RasterXSize)``
        holding the stacked ``ReadAsArray`` output of all bands.

    Raises:
        FileNotFoundError: if ``gdal.Open`` returns ``None`` for ``vrtPath``.
        ValueError: if the band count is not ``n_channels * n_times``.
    """
    ds = gdal.Open(vrtPath)
    if ds is None:  # without gdal.UseExceptions(), Open signals failure by returning None
        raise FileNotFoundError(f'GDAL could not open {vrtPath!r}')
    n_bands = ds.RasterCount
    if n_bands != n_channels * n_times:
        raise ValueError(
            f'{vrtPath!r} has {n_bands} bands; expected {n_channels} x {n_times} = {n_channels * n_times}'
        )
    # Stacking on axis 0 yields (bands, rows, cols) directly, the same values as
    # dstack followed by transpose(2, 0, 1). GDAL band indices are 1-based.
    cube = np.stack([ds.GetRasterBand(b).ReadAsArray() for b in range(1, n_bands + 1)], axis=0)
    return cube.reshape(n_channels, n_times, ds.RasterYSize, ds.RasterXSize)

def getGeoTFandProj(vrtPath):
    """Open ``vrtPath`` with GDAL and return its georeferencing.

    Returns:
        tuple ``(geo_transform, projection)``: the dataset's
        ``GetGeoTransform()`` and ``GetProjection()`` results, in that order.
    """
    dataset = gdal.Open(vrtPath)
    geo_transform = dataset.GetGeoTransform()
    projection = dataset.GetProjection()
    return geo_transform, projection



In [None]:
dat = loadVRTintoNumpyAI4(img_lab_path)

chipsize = 128*4  # edge length (pixels) of each square chip fed to the model
overlap  = 20     # pixels shared by neighbouring chips
rows, cols = dat.shape[2:]


def chip_bounds(length, size, overlap):
    """Return ``(starts, ends)`` of all full-size windows along one axis.

    Windows are ``size`` pixels long and advance by ``size - overlap``. Each one
    lies entirely inside ``[0, length]``, and the last may end exactly at
    ``length``. Returns two empty lists when ``length < size``.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``size``.
    """
    step = size - overlap
    if step <= 0:
        raise ValueError(f'overlap ({overlap}) must be smaller than the chip size ({size})')
    # `+ 1` keeps a window ending exactly at `length`; an exclusive
    # `range(size, length, step)` drops it (and yields no chip when length == size).
    starts = list(range(0, length - size + 1, step))
    ends = [start + size for start in starts]
    return starts, ends


# NOTE: windows start at 0 and are not edge-aligned, so a bottom/right strip
# narrower than one stride is not covered by any chip.
row_start, row_end = chip_bounds(rows, chipsize, overlap)
col_start, col_end = chip_bounds(cols, chipsize, overlap)

In [None]:
# Define the model architecture and load the trained weights (.pth).
model_name = dataFolder + 'output/models/model_state_All_but_LU_transformed_42.pth'
model_name_short = model_name.split('/')[-1].split('.')[0]
local_rank = 0  # kept for multi-GPU setups; not used for single-device inference

NClasses = 1
nf = 96
verbose = True
model_config = {'in_channels': 4,
                'spatial_size_init': (128, 128),
                'depths': [2, 2, 5, 2],
                'nfilters_init': nf,
                'nheads_start': nf // 4,
                'NClasses': NClasses,
                'verbose': verbose,
                'segm_act': 'sigmoid'}

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build and load the model on every machine. Previously this only happened when
# CUDA was available, so on CPU `model` was never defined and the prediction loop
# failed with a NameError. `map_location` restores GPU-saved tensors onto
# whichever device is in use.
modeli = ptavit3d_dn(**model_config).to(device)
modeli.load_state_dict(torch.load(model_name, map_location=device))
model = modeli  # already on `device`
model.eval()
    
# Run the model on every chip window in row-major order (all column windows of
# the first row window, then the next row window, ...). Each entry of `preds` is
# the model output for one chip as a CPU numpy array.
preds = []
with torch.no_grad():
    for r0, r1 in zip(row_start, row_end):
        for c0, c1 in zip(col_start, col_end):
            # Leading np.newaxis adds a batch dimension of 1.
            chip = dat[np.newaxis, :, :, r0:r1, c0:c1]
            image = torch.tensor(chip).to(torch.float).to(device)
            pred = model(image)
            preds.append(pred.detach().cpu().numpy())

In [None]:
# Write each prediction chip as a georeferenced GeoTIFF. Every edge is trimmed by
# half the overlap so neighbouring chips abut instead of overlapping.
outFolder = dataFolder + 'output/predictions/FORCE/chips/'
os.makedirs(outFolder, exist_ok=True)  # GTiff Create fails if the folder is missing
gtiff_driver = gdal.GetDriverByName('GTiff')
geoTF, geoPr = getGeoTFandProj(img_lab_path)

# One integer margin is used for both the crop and the geotransform offset, so
# they always agree. The old `[m:-m]` slice was empty for overlap == 0, and
# `overlap/2` mismatched the raster size for odd overlaps.
half = overlap // 2
out_size = chipsize - 2 * half
n_out_bands = 3  # prediction channels written per chip

# Row-major order, matching the order in which `preds` was filled.
windows = [(r0, c0) for r0 in row_start for c0 in col_start]
filenames = [f'X_{c0}_Y_{r0}.tif' for r0, c0 in windows]
if len(preds) != len(windows):
    raise RuntimeError(
        f'{len(preds)} predictions for {len(windows)} chip windows; re-run the prediction cell'
    )

x0, dx, rx, y0, ry, dy = geoTF
for (r0, c0), fname, pred in zip(windows, filenames, preds):
    path = outFolder + fname
    out_ds = gtiff_driver.Create(path, out_size, out_size, n_out_bands, gdal.GDT_Float32)
    if out_ds is None:
        raise RuntimeError(f'GDAL could not create {path}')
    # Pixel/line of the trimmed chip's upper-left corner in the source raster,
    # mapped through the full affine geotransform (rotation terms included).
    px, py = c0 + half, r0 + half
    out_ds.SetGeoTransform((x0 + px * dx + py * rx, dx, rx,
                            y0 + px * ry + py * dy, ry, dy))
    out_ds.SetProjection(geoPr)

    # Assumes pred is (batch=1, channels, H, W) with H == W == chipsize.
    for band in range(n_out_bands):
        out_ds.GetRasterBand(band + 1).WriteArray(
            pred[0, band, half:half + out_size, half:half + out_size])
    out_ds = None  # dropping the reference flushes and closes the GDAL dataset

['fields/output/predictions/FORCE/chips/X_492_Y_1476.tif',
 'fields/output/predictions/FORCE/chips/X_4920_Y_2460.tif',
 'fields/output/predictions/FORCE/chips/X_4428_Y_984.tif',
 'fields/output/predictions/FORCE/chips/X_1476_Y_2460.tif',
 'fields/output/predictions/FORCE/chips/X_5412_Y_984.tif',
 'fields/output/predictions/FORCE/chips/X_984_Y_492.tif',
 'fields/output/predictions/FORCE/chips/X_0_Y_0.tif',
 'fields/output/predictions/FORCE/chips/X_3444_Y_0.tif',
 'fields/output/predictions/FORCE/chips/X_1968_Y_1968.tif',
 'fields/output/predictions/FORCE/chips/X_2952_Y_984.tif',
 'fields/output/predictions/FORCE/chips/X_2460_Y_984.tif',
 'fields/output/predictions/FORCE/chips/X_984_Y_0.tif',
 'fields/output/predictions/FORCE/chips/X_1476_Y_1968.tif',
 'fields/output/predictions/FORCE/chips/X_1968_Y_984.tif',
 'fields/output/predictions/FORCE/chips/X_984_Y_1968.tif',
 'fields/output/predictions/FORCE/chips/X_5412_Y_2460.tif',
 'fields/output/predictions/FORCE/chips/X_2460_Y_0.tif',
 'fie