### Download the model from https://www.dropbox.com/scl/fi/srw7u4cw1gtxrf4xzmsh7/floodvit.pt?rlkey=snskpq1qrdav5u2jya8k2bocg&e=1&dl=0

### Create a modified version of the ViT model and upload it to cloud storage (Google Drive, Dropbox, or Zenodo).
### The ViT pipeline should be kept distinct from the statistical approach.

In [13]:
import xarray as xr
import rioxarray
import torch
import numpy as np
from torchvision import transforms
import xbatcher
from tqdm import tqdm

In [14]:
# Load the pretrained flood-segmentation ViT and the preprocessed Sentinel-1 stack.
# NOTE(security): the checkpoint is a whole pickled nn.Module, so torch.load runs
# arbitrary code stored in the file -- only load this trusted checkpoint.
# weights_only=False is required on torch>=2.6 (the default flipped to True and
# refuses to unpickle full models). map_location='cpu' lets a checkpoint saved on
# a GPU load on any host; the model is moved to the target device in the next cell.
vit_model = torch.load('/home/kleanthis/Projects/Thessalia_Floods_2023/Vit_model/floodvit.pt',
                       map_location='cpu', weights_only=False)
# decode_coords='all' keeps the grid-mapping / CRS coordinate variables.
S1_dataset = xr.open_dataset('/home/kleanthis/Projects/Thessalia_Floods_2023/Preprocessed_20230906T043947/S1_stack_20230906T043947.nc', decode_coords='all')

In [15]:
# Run on the GPU when available; fall back to the CPU so the notebook still runs
# (more slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vit_model.to(device)
# A pickled model keeps the train/eval flag it was saved with. eval() disables
# training-only behaviour (e.g. dropout) so the predictions are deterministic.
vit_model.eval()

# Side length in pixels of the square patches fed to the ViT. The name is kept
# for compatibility with later cells, but this is a spatial patch size, not a
# number of samples.
batch_size = 224
half_batch_size = batch_size // 2  # not used in the cells shown here

# Per-channel (VV, VH) normalization statistics of the clamped linear backscatter.
# These are presumably the training-set statistics of the model -- they must match it.
data_mean = [0.0953, 0.0264]
data_std = [0.0427, 0.0215]
# Upper bound applied to linear backscatter (also used to fill NaN pixels) before normalization.
clamp_input = 0.15
Normalize = transforms.Normalize(mean=data_mean, std=data_std)

In [16]:
# create a dataset moved by half batch size
# Build copies of the scene whose tiling is shifted by 1/4, 2/4 and 3/4 of the
# patch size in both x and y (56, 112 and 168 px for 224 px patches).
# BatchGenerator only yields full-size patches, and every patch has hard
# borders. Re-running the model on shifted tilings therefore produces
# overlapping predictions that cover those seams and edges.
# Positional isel is equivalent to the original label slice from the n-th
# coordinate to the last one (inclusive) on a monotonic grid.
quarter_patch = batch_size // 4

S1_dataset_moved_A = S1_dataset.isel(x=slice(quarter_patch, None),
                                     y=slice(quarter_patch, None))

S1_dataset_moved_B = S1_dataset.isel(x=slice(2 * quarter_patch, None),
                                     y=slice(2 * quarter_patch, None))

S1_dataset_moved_C = S1_dataset.isel(x=slice(3 * quarter_patch, None),
                                     y=slice(3 * quarter_patch, None))

In [17]:
# Acquisition times in S1_dataset's time coordinate: one post-event scene and
# two pre-event reference scenes (must match the coordinate values exactly).
post_time = '2023-09-06T04:39:47.000000000'
pre1_time = '2023-08-01T04:39:57.000000000'
pre2_time = '2023-08-13T04:39:57.000000000'
patch_dims = {'x': batch_size, 'y': batch_size}


def _to_model_input(patch):
    """Convert one xbatcher patch of dB backscatter into a normalized model tensor.

    Stacks the VV_dB and VH_dB bands into a (2, H, W) array, converts dB to
    linear power, clamps to [0, clamp_input], and replaces NaN pixels with
    clamp_input. It then applies the per-channel Normalize and returns a
    (1, 2, H, W) float32 tensor on `device`.
    """
    bands_db = np.stack([patch.VV_dB.values, patch.VH_dB.values], axis=0)
    linear = torch.from_numpy(np.power(10, bands_db / 10)).float()  # dB -> linear
    # clamp() propagates NaN, so NaN pixels are filled separately afterwards.
    linear = torch.clamp(linear, min=0.0, max=clamp_input)
    linear = torch.nan_to_num(linear, clamp_input)
    return Normalize(linear).unsqueeze(0).to(device)


# One prediction mosaic per tiling (original grid plus the three shifted copies).
prediction_xarrays_list = []
for xr_dataset in [S1_dataset, S1_dataset_moved_A, S1_dataset_moved_B, S1_dataset_moved_C]:
    predictions_batches_list = []
    post_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time=post_time), input_dims=patch_dims)
    pre1_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time=pre1_time), input_dims=patch_dims)
    pre2_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time=pre2_time), input_dims=patch_dims)

    # All three generators tile the same grid, so their i-th patches line up.
    patches = zip(post_bgen, pre1_bgen, pre2_bgen)
    for post_patch, pre1_patch, pre2_patch in tqdm(patches, total=len(post_bgen)):
        with torch.no_grad():
            # 6-channel input in the order the model was fed before:
            # post (VV, VH), pre1 (VV, VH), pre2 (VV, VH).
            model_input = torch.cat((_to_model_input(post_patch),
                                     _to_model_input(pre1_patch),
                                     _to_model_input(pre2_patch)), dim=1)
            # Per-pixel index of the highest-scoring class.
            predictions = vit_model(model_input).argmax(1)

        # Drop the singleton batch dimension -> (H, W).
        prediction_data = predictions[0].cpu().numpy()

        prediction_patch_xarray = xr.Dataset({'flood_vit': (["y", "x"], prediction_data)},
                                             coords={
                                                 "x": (["x"], post_patch.x.data),
                                                 "y": (["y"], post_patch.y.data),
                                             },
                                             )
        prediction_patch_xarray.rio.write_crs("epsg:4326", inplace=True)
        predictions_batches_list.append(prediction_patch_xarray)

    # Stitch the patches of this tiling back into one raster.
    prediction_xarrays_list.append(xr.combine_by_coords(predictions_batches_list))

# NOTE: a single merged mosaic could be built with
# xr.merge(prediction_xarrays_list, compat='override'); for now each tiling is
# written to its own NetCDF file instead.

100%|██████████| 400/400 [00:23<00:00, 17.25it/s]
100%|██████████| 375/375 [00:22<00:00, 17.01it/s]
100%|██████████| 375/375 [00:21<00:00, 17.55it/s]
100%|██████████| 360/360 [00:21<00:00, 17.08it/s]


In [18]:
# Write one NetCDF file per tiling. The suffix is the pixel shift of that
# tiling: 0, 1/4, 2/4 and 3/4 of the patch size (0, 56, 112, 168 px for 224 px
# patches), derived from batch_size so it stays in sync with the shifted datasets.
results_dir = '/home/kleanthis/Projects/Thessalia_Floods_2023/Results'
pixel_shifts = [k * (batch_size // 4) for k in range(4)]

for pixel_disp, vit_flooded_regions in zip(pixel_shifts, prediction_xarrays_list):
    # Coordinate metadata for GIS readers; modifies the dataset in place.
    # NOTE(review): 'X'/'Y' are not CF standard names (CF uses 'longitude'/'latitude'
    # with 'degrees_east'/'degrees_north'); kept unchanged for output compatibility.
    for dim in ('x', 'y'):
        label = dim.upper()
        coord_attrs = vit_flooded_regions[dim].attrs
        coord_attrs['standard_name'] = label
        coord_attrs['long_name'] = f'Coordinate {label}'
        coord_attrs['units'] = 'degrees'
        coord_attrs['axis'] = label
    vit_flooded_regions.rio.write_crs("epsg:4326", inplace=True)

    vit_flooded_regions.to_netcdf(f'{results_dir}/Flood_vit_{pixel_disp}.nc', format='NETCDF4')