In [3]:
%matplotlib inline
import numpy as np
import rasterio, glob, xarray as xr
import os,sys
import albumentations as A
from albumentations.core.transforms_interface import  ImageOnlyTransform
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
sys.path.append(r'/home/repos')
from torch.utils.data import DataLoader                                                                                 
from tfcl.models.ptavit3d.ptavit3d_dn import ptavit3d_dn       
from tfcl.nn.loss.ftnmt_loss import ftnmt_loss               
from tfcl.utils.classification_metric import Classification  
from datetime import datetime
import time

def getFilelist(originpath, ftyp):
    """Return the paths of files in ``originpath`` whose extension is in ``ftyp``.

    Parameters
    ----------
    originpath : str
        Directory to list (non-recursive). Order follows ``os.listdir``,
        i.e. it is not sorted.
    ftyp : str or iterable of str
        Accepted extension(s) without the leading dot, e.g. ``'tif'`` or
        ``['tif', 'nc']``. A single string is treated as ONE extension
        (a plain ``ext in 'tif'`` test would be a substring match and
        accept ``.if``, ``.t``, ``.ti`` ...).

    Returns
    -------
    list of str
        ``originpath`` and each matching file name joined by exactly one '/'.
    """
    if isinstance(ftyp, str):
        ftyp = (ftyp,)
    accepted = set(ftyp)
    prefix = originpath if originpath.endswith('/') else originpath + '/'
    return [prefix + fname
            for fname in os.listdir(originpath)
            if fname.split('.')[-1] in accepted]

def export_np_to_tif(arr, src, path, name):
    """Write a band-first array to ``path + name + '.tif'`` as a GeoTIFF.

    ``arr`` is indexed as (band, row, col): ``arr.shape[0]`` bands, band ``b``
    being written to 1-based band index ``b + 1``. Only the affine
    ``transform`` is copied from the open rasterio dataset ``src``; the CRS is
    left as ``None`` and no nodata value is set. ``path`` and ``name`` are
    concatenated verbatim, so the caller supplies any separator.
    """
    n_bands, n_rows, n_cols = arr.shape[0], arr.shape[1], arr.shape[2]
    profile = dict(
        crs=None,        # src.crs is not propagated
        nodata=None,     # change if data has a nodata value
        transform=src.transform,
        driver='GTiff',
        height=n_rows,
        width=n_cols,
        count=n_bands,
        dtype=arr.dtype,
    )
    with rasterio.open(path + name + '.tif', 'w', **profile) as dst:
        for band in range(n_bands):
            dst.write(arr[band], band + 1)
# Normalization and transform functions

class AI4BNormal_S2(object):
    """
    Per-channel standardisation ``(x - mean) / std`` for channel-first images.

    The channel axis must be the FIRST axis of the input; any number of
    trailing axes is allowed, e.g. (C, H, W) or (C, T, H, W). Internally the
    array is viewed transposed so the channel axis is last and the statistics
    broadcast over it.

    The default statistics hold 4 values, presumably for the bands B2, B3, B4,
    B8 that ``VALIDataset.ds2rstr`` stacks -- confirm against the script that
    computed them.
    """

    # Built-in statistics, used when no mean/std is passed.
    DEFAULT_MEAN = (5.4418573e+02, 7.6761194e+02, 7.1712860e+02, 2.8561428e+03)
    DEFAULT_STD = (3.7141626e+02, 3.8981952e+02, 4.7989127e+02, 9.5173022e+02)

    def __init__(self, mean=None, std=None):
        """
        Parameters
        ----------
        mean, std : sequence of float, optional
            Per-channel statistics, one value per input channel. ``None``
            selects the built-in 4-band defaults.

        Raises
        ------
        ValueError
            If mean/std are not 1-D sequences of the same length.
        """
        self._mean_s2 = np.asarray(self.DEFAULT_MEAN if mean is None else mean, dtype=np.float32)
        self._std_s2 = np.asarray(self.DEFAULT_STD if std is None else std, dtype=np.float32)
        if self._mean_s2.ndim != 1 or self._mean_s2.shape != self._std_s2.shape:
            raise ValueError('mean and std must be 1-D sequences of equal length')

    def __call__(self, img):
        """
        Return a new float32 array of ``img``'s shape, standardised per channel.

        ``img`` itself is not modified.

        Raises
        ------
        ValueError
            If ``img.shape[0]`` differs from the number of channel statistics.
        """
        n_channels = self._mean_s2.shape[0]
        if img.shape[0] != n_channels:
            raise ValueError('expected {} channels on axis 0, got shape {}'.format(n_channels, img.shape))
        out = img.astype(np.float32)   # astype copies, so the in-place ops below never touch img
        channels_last = out.T          # view of `out` with the channel axis last, for broadcasting
        channels_last -= self._mean_s2
        channels_last /= self._std_s2
        return out
    
class TrainingTransformS2(object):
    """
    Normalisation + geometric augmentation for Sentinel-2 time-series samples.

    Built on Albumentations; only geometric transforms are used, so the image
    and the mask receive the same spatial warp.

    Parameters
    ----------
    prob : float
        Probability that the whole geometric pipeline is applied
        (``A.Compose(p=prob)``).
    mode : {'train', 'valid'}
        'train' -> normalise, random 128x128 crop, then one random geometric
        transform; 'valid' -> normalise only.
    norm : callable or None
        Applied to the image first; ``None`` disables normalisation.
        NOTE: the default instance is created once, when the class is defined,
        and shared by every TrainingTransformS2 that does not pass its own.

    Calling ``transform(img, mask)`` returns ``(img, mask_as_float32)``.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'train' nor 'valid'.
    """
    def __init__(self,  prob = 1., mode='train', norm = AI4BNormal_S2() ):
        self.geom_trans = A.Compose([
                    A.RandomCrop(width=128, height=128, p=1.0),  # Always apply random crop
                    A.OneOf([
                        A.HorizontalFlip(p=1),
                        A.VerticalFlip(p=1),
                        A.ElasticTransform(p=1),  # local elastic (non-rigid) deformation; noticeably slow
                        A.GridDistortion(distort_limit=0.4, p=1.),
                        # Albumentations biases scale_limit by +1: (-0.25, 0.25) samples a
                        # zoom factor in [0.75, 1.25]. Writing (0.75, 1.25) here would sample
                        # [1.75, 2.25], i.e. always a ~2x zoom-in.
                        A.ShiftScaleRotate(shift_limit=0.25, scale_limit=(-0.25, 0.25), rotate_limit=180, p=1.0),  # Most important Augmentation
                        ], p=1.)
                    ],
            additional_targets={'imageS1': 'image', 'mask': 'mask'},
            p = prob)
        if mode == 'train':
            self.mytransform = self.transform_train
        elif mode == 'valid':
            self.mytransform = self.transform_valid
        else:
            raise ValueError('transform mode can only be train or valid')

        self.norm = norm

    def transform_valid(self, data):
        """Normalise the image only; return ``(image, mask as float32)``."""
        timgS2, tmask = data
        if self.norm is not None:
            timgS2 = self.norm(timgS2)
        return timgS2, tmask.astype(np.float32)

    def transform_train(self, data):
        """Normalise, then apply the shared geometric pipeline to image and mask.

        The image is unpacked as (C, T, H, W); channel and time axes are folded
        into one so Albumentations sees a single H x W x (C*T) image and every
        time step gets the identical warp. The mask is transposed as a 3-D
        (K, H, W) array. Returns ``(image (C, T, H', W'), mask (K, H', W'))``
        where H', W' are the post-crop sizes.
        """
        timgS2, tmask = data

        if self.norm is not None:
            timgS2 = self.norm(timgS2)
        tmask = tmask.astype(np.float32)

        # Special treatment of time series: fold time into channels.
        c2, t, h, w = timgS2.shape
        timgS2 = timgS2.reshape(c2 * t, h, w)
        result = self.geom_trans(image=timgS2.transpose([1, 2, 0]),
                                 mask=tmask.transpose([1, 2, 0]))
        timgS2_t = result['image'].transpose([2, 0, 1])
        tmask_t = result['mask'].transpose([2, 0, 1])

        _, h2, w2 = timgS2_t.shape  # spatial size after the crop/warp
        return timgS2_t.reshape(c2, t, h2, w2), tmask_t

    def __call__(self, *data):
        return self.mytransform(data)

class VALIDataset(torch.utils.data.Dataset):
    """
    Inference dataset pairing Sentinel-2 NetCDF time series with mask GeoTIFFs.

    Files are discovered under ``<path_to_data>/images/<country>/*.nc`` and
    ``<path_to_data>/masks/<country>/*.tif`` and paired by position after
    sorting, so the i-th image must correspond to the i-th mask (only the
    counts are checked). Samples are returned raw: no normalisation is applied.
    """
    def __init__(self, path_to_data=r'//home/ai4boundaries/sentinel2/', country='LU'):
        """
        Parameters
        ----------
        path_to_data : str
            Root directory containing ``images/`` and ``masks/``.
        country : str
            Sub-folder to read (defaults to 'LU').

        Raises
        ------
        ValueError
            If the numbers of image and mask files differ.
        """
        self.flnames_s2_img = sorted(glob.glob(os.path.join(path_to_data, 'images', country, '*.nc')))
        self.flnames_s2_mask = sorted(glob.glob(os.path.join(path_to_data, 'masks', country, '*.tif')))

        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # and `assert cond, ValueError(...)` never raised a ValueError anyway.
        n_img, n_mask = len(self.flnames_s2_img), len(self.flnames_s2_mask)
        if n_img != n_mask:
            raise ValueError("Some problem, the masks and images are not in the same numbers, aborting "
                             "({} images vs {} masks)".format(n_img, n_mask))

    # Helper function to read nc to raster
    def ds2rstr(self, tname):
        """Stack the S2 bands of one NetCDF file into a band-first numpy array.

        Each band variable gets a new leading axis and the results are
        concatenated, giving shape (4, *variable_shape) -- e.g. (4, 6, 256, 256)
        per the shapes printed by the inference loop.
        """
        variables2use = ['B2', 'B3', 'B4', 'B8']  # ,'NDVI']
        # Context manager closes the file; `.values` materialises each variable
        # as a numpy array first, so the result remains valid after closing.
        with xr.open_dataset(tname) as ds:
            return np.concatenate([ds[var].values[None] for var in variables2use], 0)

    def read_mask(self, tname):
        """Read bands 1, 2 and 3 of a mask GeoTIFF, closing the file afterwards."""
        with rasterio.open(tname) as src:
            return src.read((1, 2, 3))

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the idx-th sorted file pair (unnormalised)."""
        return (self.ds2rstr(self.flnames_s2_img[idx]),
                self.read_mask(self.flnames_s2_mask[idx]))

    def __len__(self):
        return len(self.flnames_s2_img)


local_rank = 0
# torch.cuda.set_device(local_rank)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
torch.manual_seed(0)

# Run tag: selects the checkpoint file and names the prediction folder/files
# (it is not the country of the tiles being predicted, which are read from 'LU').
country = 'ES_nb_10epochs_7'
NClasses = 1
nf = 96
verbose = True
model_config = {'in_channels': 4,
                'spatial_size_init': (128, 128),
                'depths': [2, 2, 5, 2],
                'nfilters_init': nf,
                'nheads_start': nf // 4,
                'NClasses': NClasses,
                'verbose': verbose,
                'segm_act': 'sigmoid'}

modeli = ptavit3d_dn(**model_config).to(device)
# map_location lets a checkpoint saved on a GPU load onto whichever device was
# selected above; without it the CPU fallback fails on CUDA-saved tensors.
modeli.load_state_dict(torch.load('/home/output/models/model_state_' + country + '.pth',
                                  map_location=device))

model = modeli  # already on `device`; nn.Module.to() returns the same module
model.eval()


# NOTE(review): VALIDataset returns raw band values; AI4BNormal_S2 (defined above)
# is never applied in this notebook. If the checkpoint was trained on normalised
# inputs, normalise each image before the forward pass -- confirm against the
# training code.
vdata = VALIDataset()

preds = []

valid_loader = DataLoader(dataset=vdata, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

# Reference rasters used for georeferencing each prediction. Prediction i comes
# from dataset sample i, so both lists must line up file-for-file.
trg = sorted(glob.glob(os.path.join('/home/ai4boundaries/sentinel2/masks/' + 'LU' + '/*.tif')))
assert [os.path.basename(p) for p in trg] == [os.path.basename(p) for p in vdata.flnames_s2_mask], \
    "reference masks do not line up with the dataset's mask files"

outDir = '/home/output/predictions/' + country
os.makedirs(outDir, exist_ok=True)

with torch.no_grad():
    for i, (image, label) in enumerate(valid_loader):
        print(image.shape)
        # Use the selected device rather than a hard-coded .cuda(), so the CPU
        # fallback chosen at setup actually works.
        pred = model(image.to(device))
        preds.append(pred.cpu().numpy())

        tile_id = trg[i].split('/')[-1].split('_')[1]
        # Context manager closes the reference raster each iteration instead of
        # leaking one open dataset per tile.
        with rasterio.open(trg[i]) as ref:
            export_np_to_tif(preds[i][0, :, :, :], ref, outDir,
                             '/' + country + '_pred_LU_' + tile_id)



cuda:0
 @@@@@@@@@@@@@ Going DOWN @@@@@@@@@@@@@@@@@@@ 
depth:= 0, layer_dim_in: 96, layer_dim: 96, stage_depth::2, spatial_size::(32, 32), scales::[16, 8, 8]
depth:= 1, layer_dim_in: 96, layer_dim: 192, stage_depth::2, spatial_size::(16, 16), scales::[32, 4, 4]
depth:= 2, layer_dim_in: 192, layer_dim: 384, stage_depth::5, spatial_size::(8, 8), scales::[64, 2, 2]
depth:= 3, layer_dim_in: 384, layer_dim: 768, stage_depth::2, spatial_size::(4, 4), scales::[128, 1, 1]
 XXXXXXXXXXXXXXXXXXXXX Coming up XXXXXXXXXXXXXXXXXXXXXXXXX 
depth:= 4, layer_dim_in: 384, layer_dim: 384, stage_depth::5, spatial_size::(8, 8), scales::[64, 2, 2]
depth:= 5, layer_dim_in: 192, layer_dim: 192, stage_depth::2, spatial_size::(16, 16), scales::[32, 4, 4]
depth:= 6, layer_dim_in: 96, layer_dim: 96, stage_depth::2, spatial_size::(32, 32), scales::[16, 8, 8]
torch.Size([1, 4, 6, 256, 256])
torch.Size([1, 4, 6, 256, 256])
torch.Size([1, 4, 6, 256, 256])
torch.Size([1, 4, 6, 256, 256])
torch.Size([1, 4, 6, 256, 256])
t

In [None]:
# Quick sanity check of the collected predictions.
# (export_np_to_tif is defined in the first cell; a byte-identical re-definition
# here only shadowed it, so it is not repeated.)
print(len(preds))
if preds:
    # Full shape of the first prediction; it is 4-D with a leading batch axis
    # of size 1 (batch_size=1 in the DataLoader).
    print(preds[0].shape)