In [None]:
%matplotlib inline
import numpy as np
import rasterio, glob, xarray as xr
import os,sys
import albumentations as A
from albumentations.core.transforms_interface import  ImageOnlyTransform
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
sys.path.append(r'/home/repos')
from torch.utils.data import DataLoader                                                                                 
from tfcl.models.ptavit3d.ptavit3d_dn import ptavit3d_dn       
from tfcl.nn.loss.ftnmt_loss import ftnmt_loss               
from tfcl.utils.classification_metric import Classification  
from datetime import datetime   
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import pandas as pd
import random
from rocksdbutils_copy import * 
from math import ceil as mceil
import time

# Global debug switch; keep it False for real training runs.
DEBUG = False


# Normalization and transform functions

class AI4BNormal_S2(object):
    """
    Per-channel standardisation ``(x - mean) / std`` for channel-first arrays.

    The channel axis must be axis 0; any number of trailing axes is accepted
    (e.g. ``(C, H, W)`` or the ``(C, T, H, W)`` time series used by the
    transforms in this notebook). The input is cast to float32 and copied,
    so the caller's array is never modified.

    Parameters
    ----------
    mean, std : sequence of float, optional
        Per-channel statistics. They default to the built-in 4-channel
        Sentinel-2 values, so ``AI4BNormal_S2()`` behaves exactly as before.

    Raises
    ------
    ValueError
        If ``mean``/``std`` are not 1-D of equal length, or ``std`` contains a
        zero (at construction). Also raised at call time if the input's
        channel count (axis 0) does not match the statistics.
    """

    # Default per-channel statistics (4 bands) used throughout this project.
    _DEFAULT_MEAN = (5.4418573e+02, 7.6761194e+02, 7.1712860e+02, 2.8561428e+03)
    _DEFAULT_STD = (3.7141626e+02, 3.8981952e+02, 4.7989127e+02, 9.5173022e+02)

    def __init__(self, mean=None, std=None):
        self._mean_s2 = np.asarray(self._DEFAULT_MEAN if mean is None else mean, dtype=np.float32)
        self._std_s2 = np.asarray(self._DEFAULT_STD if std is None else std, dtype=np.float32)
        if self._mean_s2.ndim != 1 or self._mean_s2.shape != self._std_s2.shape:
            raise ValueError("mean and std must be 1-D sequences of equal length")
        if np.any(self._std_s2 == 0):
            raise ValueError("std must not contain zeros")

    def __call__(self, img):
        out = np.array(img, dtype=np.float32)  # always a fresh float32 copy
        n_channels = self._mean_s2.shape[0]
        if out.ndim == 0 or out.shape[0] != n_channels:
            raise ValueError(
                f"expected {n_channels} channels on axis 0, got array of shape {out.shape}")
        # Reverse the axes so channels are last and the (C,) statistics broadcast;
        # the in-place ops write through this view into `out`.
        out_t = out.T
        out_t -= self._mean_s2
        out_t /= self._std_s2
        return out

class TrainingTransform_for_rocks_Train(object):
    """
    Normalise an image/mask pair and, in 'train' mode, horizontally flip both.

    Call as ``transform(img, mask)``. ``img`` is a channel-first time series
    of shape (C, T, H, W); ``mask`` is a 3-D channel-first array with the same
    spatial size. Returns ``(img, mask)`` with the mask cast to float32.

    With the default ``prob=1`` every sample is flipped, so this pipeline is
    deterministic. (Vertical flip, elastic, grid-distortion and
    shift-scale-rotate augmentations were tried here and disabled.)

    Parameters
    ----------
    prob : float, default 1
        Probability passed to ``A.Compose`` for applying the pipeline.
    norm : callable or None, default ``AI4BNormal_S2()``
        Applied to the image first; ``None`` skips normalisation.
    mode : {'train', 'valid'}, default 'train'
        'train' normalises and flips; 'valid' only normalises. The class used
        to be hard-wired to 'train', which left ``transform_valid``
        unreachable; this parameter mirrors ``TrainingTransformS2``.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'train' nor 'valid'.
    """

    def __init__(self, prob=1, norm=AI4BNormal_S2(), mode='train'):
        self.geom_trans = A.Compose([A.HorizontalFlip(p=1)], p=prob)
        if mode == 'train':
            self.mytransform = self.transform_train
        elif mode == 'valid':
            self.mytransform = self.transform_valid
        else:
            raise ValueError('transform mode can only be train or valid')
        self.norm = norm

    def transform_valid(self, data):
        """Normalise the image (if ``norm`` is set) and cast the mask to float32."""
        timgS2, tmask = data
        if self.norm is not None:
            timgS2 = self.norm(timgS2)
        return timgS2, tmask.astype(np.float32)

    def transform_train(self, data):
        """Normalise, then apply ``geom_trans`` jointly to image and mask."""
        timgS2, tmask = data
        if self.norm is not None:
            timgS2 = self.norm(timgS2)
        tmask = tmask.astype(np.float32)

        # Albumentations expects HWC: fold time into channels (C*T, H, W), then move channels last
        c2, t, h, w = timgS2.shape
        timgS2 = timgS2.reshape(c2 * t, h, w)
        result = self.geom_trans(image=timgS2.transpose([1, 2, 0]),
                                 mask=tmask.transpose([1, 2, 0]))

        # Back to channel-first and unfold time; H/W are read back from the output
        timgS2_t = result['image'].transpose([2, 0, 1])
        tmask_t = result['mask'].transpose([2, 0, 1])
        _, h2, w2 = timgS2_t.shape
        return timgS2_t.reshape(c2, t, h2, w2), tmask_t

    def __call__(self, *data):
        # Positional args arrive as the (img, mask) tuple, dispatched to the selected mode
        return self.mytransform(data)

class TrainingTransformS2(object):
    """
    Normalise an image/mask pair and optionally run a geometric augmentation chain.

    Invoked as ``transform(img, mask)``, where ``img`` is a (C, T, H, W) time
    series and ``mask`` a 3-D channel-first array. In 'train' mode the image
    is normalised and ``geom_trans`` is applied to image and mask together;
    in 'valid' mode the image is only normalised. The mask is always returned
    as float32.

    NOTE(review): the five augmentations run one after another (an ``A.OneOf``
    wrapper that would pick just one of them was disabled), each with p=1. With
    the default ``prob=1.`` every sample is therefore flipped horizontally AND
    vertically (a fixed 180-degree rotation) before the elastic, grid and
    shift-scale-rotate warps. Confirm this is intended.
    """

    def __init__(self, prob=1., mode='train', norm=AI4BNormal_S2()):
        augmentations = [
            A.HorizontalFlip(p=1),
            A.VerticalFlip(p=1),
            A.ElasticTransform(p=1),  # perspective-like warp; very useful but VERY SLOW
            A.GridDistortion(distort_limit=0.4, p=1.),
            A.ShiftScaleRotate(shift_limit=0.25, scale_limit=(0.75, 1.25),
                               rotate_limit=180, p=1.0),  # most important augmentation
        ]
        self.geom_trans = A.Compose(
            augmentations,
            additional_targets={'imageS1': 'image', 'mask': 'mask'},
            p=prob,
        )

        if mode not in ('train', 'valid'):
            raise ValueError('transform mode can only be train or valid')
        self.mytransform = self.transform_train if mode == 'train' else self.transform_valid
        self.norm = norm

    def transform_valid(self, data):
        """Validation path: normalise only; the mask is cast to float32."""
        img, mask = data
        img = img if self.norm is None else self.norm(img)
        return img, mask.astype(np.float32)

    def transform_train(self, data):
        """Training path: normalise, then augment image and mask together."""
        img, mask = data
        img = img if self.norm is None else self.norm(img)
        mask = mask.astype(np.float32)

        n_chan, n_time, height, width = img.shape
        # Albumentations works on HWC, so stack the time steps into the channel axis first
        img_hwc = img.reshape(n_chan * n_time, height, width).transpose(1, 2, 0)
        augmented = self.geom_trans(image=img_hwc, mask=mask.transpose(1, 2, 0))

        img_chw = augmented['image'].transpose(2, 0, 1)
        mask_chw = augmented['mask'].transpose(2, 0, 1)
        _, out_h, out_w = img_chw.shape  # spatial size is read back from the output
        return img_chw.reshape(n_chan, n_time, out_h, out_w), mask_chw

    def __call__(self, *data):
        # Positional args arrive as the (img, mask) tuple, dispatched to the selected mode
        return self.mytransform(data)

def plotter(array, n_slices=4, cmap='viridis'):
    """
    Plot evenly spaced 2-D slices along axis 0 of a 3-D array, each with its own colour bar.

    Parameters
    ----------
    array : array-like, shape (N, H, W)
        Stack to display; panel k shows ``array[idx_k]``.
    n_slices : int, default 4
        Number of panels. It is clamped to N, so no slice is shown twice.
    cmap : str or matplotlib colormap, default 'viridis'

    Tick positions are derived from H and W (every quarter plus the last
    pixel). For 128x128 input they are 0/32/64/96/127 as before; 64x64 chips
    get in-range ticks instead of ticks past the image edge.

    Raises
    ------
    ValueError
        If ``array`` is not 3-D or has no slices.
    """
    array = np.asarray(array)
    if array.ndim != 3 or array.shape[0] == 0:
        raise ValueError(f"plotter expects a non-empty (N, H, W) array, got shape {array.shape}")

    def _ticks(size):
        # quarter positions plus the last index, de-duplicated for tiny sizes
        return sorted(set(range(0, size, max(1, size // 4))) | {size - 1})

    n_panels = max(1, min(int(n_slices), array.shape[0]))
    height, width = array.shape[1:]
    xticks, yticks = _ticks(width), _ticks(height)

    # squeeze=False keeps `axes` 2-D even for a single panel
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6),
                             constrained_layout=False, squeeze=False)
    slice_indices = np.linspace(0, array.shape[0] - 1, n_panels, dtype=int)

    for ax, idx in zip(axes[0], slice_indices):
        im = ax.imshow(array[idx], cmap=cmap)
        ax.set_title(f"Slice {idx}")
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels([f"X{t}" for t in xticks])
        ax.set_yticklabels([f"Y{t}" for t in yticks])

        # horizontal colour bar in an inset strip below each panel
        cbar_ax = ax.inset_axes([0.1, -0.2, 0.8, 0.05])  # [x, y, width, height] in axes coords
        cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Value Scale')

    plt.tight_layout()
    plt.show()

def getFilelist(originpath, ftyp):
    """
    List the files in ``originpath`` whose extension matches ``ftyp``.

    Parameters
    ----------
    originpath : str
        Directory to scan (not recursive), with or without a trailing '/'.
    ftyp : str or iterable of str
        One extension (``'tif'``) or several (``['tif', 'nc']``). A leading
        dot is ignored. Matching is exact and case-sensitive.

    Returns
    -------
    list of str
        ``originpath`` joined with each matching file name, sorted by name so
        the order is reproducible (``os.listdir`` order is arbitrary).

    Notes
    -----
    A string ``ftyp`` used to be tested with ``ext in ftyp``, which is a
    substring check. As a result 'tif' also matched extensions such as 't',
    'if' or '' (``'name.'``), as well as dot-less names like ``'tif'``.
    """
    wanted = {ftyp} if isinstance(ftyp, str) else set(ftyp)
    wanted = {ext.lstrip('.') for ext in wanted}
    out = []
    for name in sorted(os.listdir(originpath)):
        if '.' not in name:
            continue  # no extension at all
        if name.rsplit('.', 1)[1] in wanted:
            out.append(os.path.join(originpath, name))
    return out


In [2]:
# Load the 128x128 train/valid RocksDB datasets. Indexing them returns raw samples:
# no transform/normalisation is applied unless one is specified explicitly.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)  # seed NumPy's global RNG as well, not only Python's `random`
# NOTE(review): whether Albumentations draws from these global RNGs depends on its
# version (newer releases take a `seed` argument on `A.Compose`) — confirm.
DB_ROOT = '/home/output/rocks_db'
country = 'ES_no_empty_label'
train_dataset = RocksDBDataset(f'{DB_ROOT}/{country}.db/train.db')
valid_dataset = RocksDBDataset(f'{DB_ROOT}/{country}.db/valid.db')

In [3]:
# Size of the validation RocksDB (shown as the cell output)
n_valid_chips = len(valid_dataset)
n_valid_chips


1210

In [None]:
# Visual check of normalisation + augmentation on a single training chip.
img_no = 450
year = 5  # index on the time axis (axis 1 of the (C, T, H, W) image) to display

# Fetch the sample once; each lookup presumably reads from RocksDB, and it used to be indexed twice.
sample_img, sample_lbl = train_dataset[img_no]

# Apply normalisation and augmentation
sample_img_aug, sample_lbl_aug = TrainingTransformS2(mode='train')(sample_img, sample_lbl)

print(sample_img_aug.shape, sample_lbl_aug.shape)
plotter(sample_img[:, year, :, :])
plotter(sample_img_aug[:, year, :, :])
plotter(sample_lbl)
plotter(sample_lbl_aug)

In [3]:
# "Empty" label = a chip whose label has a single unique value. Collect empty labels
# before and after TrainingTransform_for_rocks_Train in ONE pass over the DB (it used
# to be read twice), then find the chips that are empty in both.
augment = TrainingTransform_for_rocks_Train()  # build once instead of once per sample

label_List, label_index_List = [], []                          # before augmentation
label_transformed_List, label_transformed_index_List = [], []  # after augmentation

for i, (image, label) in enumerate(train_dataset):
    raw_classes = np.unique(label)
    if len(raw_classes) == 1:
        label_List.append(raw_classes)
        label_index_List.append(i)

    _, label_transformed = augment(image, label)
    aug_classes = np.unique(label_transformed)
    if len(aug_classes) == 1:
        label_transformed_List.append(aug_classes)
        label_transformed_index_List.append(i)

# Indices of labels that are empty both before and after the transformation
overlap = np.intersect1d(label_index_List, label_transformed_index_List)

In [None]:
# Number of chips in the 128x128 train and valid RocksDBs, one per line
for ds in (train_dataset, valid_dataset):
    print(len(ds))


In [5]:
# 64x64-chip RocksDB: count empty (single-valued) labels before and after
# TrainingTransformS2 augmentation, to compare with the 128x128 results.
train_dataset64 = RocksDBDataset('/home/output/rocks_db/ES_64img.db/train.db')
valid_dataset64 = RocksDBDataset('/home/output/rocks_db/ES_64img.db/valid.db')

augment64 = TrainingTransformS2(mode='train')  # build once instead of once per sample
label_List64, label_index_List64 = [], []                          # before augmentation
label_transformed_List64, label_transformed_index_List64 = [], []  # after augmentation

# Single pass over the DB (it used to be iterated twice)
for i, (image, label) in enumerate(train_dataset64):
    raw_classes = np.unique(label)
    if len(raw_classes) == 1:
        label_List64.append(raw_classes)
        label_index_List64.append(i)

    _, label_transformed = augment64(image, label)
    aug_classes = np.unique(label_transformed)
    if len(aug_classes) == 1:
        label_transformed_List64.append(aug_classes)
        label_transformed_index_List64.append(i)


In [None]:
# Summary of the 128x128 RocksDB vs. the raw AI4Boundaries masks.
# NOTE: this reassigns the global `country`; the raw-mask scan further down reads it.
country = 'ES'
# NOTE(review): 47 raw masks are excluded from the count without explanation —
# presumably the images dropped when building the 'ES_no_empty_label' DB; confirm.
N_EXCLUDED_MASKS = 47
lenai4 = len(getFilelist(f'/home/ai4boundaries/sentinel2/masks/{country}/', 'tif')) - N_EXCLUDED_MASKS
lentrain = len(train_dataset) / lenai4
lenvalid = len(valid_dataset) / lenai4

print("\n".join([
    f"Exploration of results of the {country} database",
    f"The amount of images in ai4boundaries for '{country}' is {lenai4} images",
    "",
    "The parameters for creating rocks.db= stride_divisor = 2, batch_size = 2, train_split = 0.9",
    "With these settings and a decrease in resolution from 256x256 to 128x128:",
    f"--> one image from ai4boundaries is cut into {int(lentrain + lenvalid)} chips; {int(lentrain)} go into",
    f"training and {int(lenvalid)} into validation rocks database",
]))

# label_List can be empty (no single-valued labels found), which used to raise ZeroDivisionError
if label_List:
    overlap_line = (f"{len(overlap) / len(label_List) * 100:.2f}% of empty labels before "
                    "transformation still empty after transformation")
else:
    overlap_line = "No empty labels before transformation, so none can remain empty after it"

print("\n".join([
    "",
    f"Empty labels from training dataset (n={len(train_dataset)}, for image chips 128x128):",
    f"Before transformation: {len(label_List)}",
    f"After transformation: {len(label_transformed_List)}",
    "",
    overlap_line,
    f"The ratio of validation to training in terms of dataset sizes is {round(len(valid_dataset) / len(train_dataset), 2)}",
]))


In [None]:
# 64x64 DB: labels empty both before and after augmentation, and chips per raw image.
overlap64 = np.intersect1d(label_index_List64, label_transformed_index_List64)
# NOTE(review): unlike the 128x128 count, no masks are excluded here — confirm intended.
lenai64 = len(getFilelist('/home/ai4boundaries/sentinel2/masks/ES/', 'tif'))
lentrain64, lenvalid64 = (len(ds) / lenai64 for ds in (train_dataset64, valid_dataset64))

In [None]:
print("\n--------------------------------------------------------------------------------------------")

# Guard against an empty label_List64 (ZeroDivisionError) and format with :.2f
# (round(x, 2) * 100 could print values such as 56.99999999999999).
if label_List64:
    overlap_line64 = (f"{len(overlap64) / len(label_List64) * 100:.2f}% of empty labels before "
                      "transformation still empty after transformation")
else:
    overlap_line64 = "No empty labels before transformation, so none can remain empty after it"

print("\n".join([
    "",
    "If image is split into 64x64 chips:",
    f"Empty labels from training dataset (n={len(train_dataset64)})",
    f"Before transformation: {len(label_List64)}",
    f"After transformation: {len(label_transformed_List64)}",
    "",
    overlap_line64,
    f"The ratio of validation to training in terms of dataset sizes is {round(len(valid_dataset64) / len(train_dataset64), 2)}",
    "",
    "With the settings from above and a decrease in resolution from 256x256 to 64x64:",
    f"--> one image from ai4boundaries is cut into {int(lentrain64 + lenvalid64)} chips; {int(lentrain64)} go into",
    f"training and {int(lenvalid64)} into validation rocks database",
]))


In [None]:
# Hard-coded epoch timings (s/it) for the plain example notebook vs. the RocksDB pipeline.
# The six-space runs before each "\n" reproduce the original output exactly.
secs_per_it_plain, secs_per_it_rocks = 165, 927
speed_ratio = round(secs_per_it_rocks / secs_per_it_plain, 2)
print(
    "Speed considerations:\n"
    "The example notebook that does not use rocksdb needs ~ 165 s/it in an epoch with the 'ES' dataset"
    "      \nUsing rocksdb for the same dataset results in ~ 927 s/it"
    f"      \nThis means, it needs {speed_ratio} times longer to use rocksdb."
    "      \nHowever, the notebook only takes one 128x128 image chip for training and one for validation, rocksdb uses 4 for training and ?4? for validation"
    "      \nThis means, the speed might be the same, if not a bit faster for rocksdb, if all validation images were used"
)

In [None]:
# Scan the raw AI4Boundaries .tif masks (sorted by path) for completely empty labels,
# i.e. rasters with a single unique value. Reads the global `country`.
masks = sorted(getFilelist(f'/home/ai4boundaries/sentinel2/masks/{country}/', 'tif'))
mask_index0 = []
for mask_idx, mask_path in enumerate(masks):
    with rasterio.open(mask_path) as src:
        if np.unique(src.read()).size == 1:
            mask_index0.append(mask_idx)
counter = len(mask_index0)

print(f"\n{counter} images from ai4boundaries have empty labels")


In [None]:
# Cross-check: each raw mask found empty above should give empty chips in the training RocksDB.
# NOTE(review): the index mapping below assumes the training DB holds exactly
# CHIPS_PER_IMAGE consecutive chips per raw mask, in the same sorted order as `masks`.
# That does not hold if images were dropped while building the DB (cf. the
# 'ES_no_empty_label' name) — confirm before trusting the result.
CHIPS_PER_IMAGE = 4

# This text used to be a bare string in the middle of the cell, so Jupyter never showed it.
print(f"However, as in the tif labels are only {counter} empty, there should be only "
      f"{counter * CHIPS_PER_IMAGE} empty labels in the training rocksdb dataset\n"
      f"as there are {CHIPS_PER_IMAGE} chips per image in this section of the database.\n"
      "Let's look into that:")

# Chip indices in the training RocksDB that should be empty
should_be_empty = [img_idx * CHIPS_PER_IMAGE + k
                   for img_idx in mask_index0 for k in range(CHIPS_PER_IMAGE)]
for chip_idx in should_be_empty:
    if len(np.unique(train_dataset[chip_idx][1])) != 1:
        print(f"Empty raw label is no longer empty after insertion into rocks at index {chip_idx}")

# Chips that are empty in the DB although their raw mask is not (set lookup: O(1) per check)
should_be_empty_set = set(should_be_empty)
add0 = [i for i in label_index_List if i not in should_be_empty_set]
print(len(add0))
if len(label_index_List) - len(add0) != len(should_be_empty):
    print("Something is wrong")

In [None]:
# Map each unexpectedly empty chip index back to its raw AI4Boundaries image index (4 chips per image)
raw_ind = [chip_idx // 4 for chip_idx in add0]
raw_ind

In [None]:
# All channels of training chip 3 at time index 1
chip_img = train_dataset[3][0]
plotter(chip_img[:, 1])

In [None]:
# test the generate slice function

def slicetester(shape, Fi, s, include_corner=False):
    """
    Re-implementation of the sliding-window slice generator, for testing.

    Tiles a raster whose last two axes are (rows, cols) with ``Fi x Fi``
    windows at stride ``s``. The last row/column of windows is snapped to
    the raster edge (offset ``size - Fi``) so the border is covered.

    Parameters
    ----------
    shape : tuple of int
        Raster shape; only ``shape[-2]`` (rows) and ``shape[-1]`` (cols) are used.
    Fi : int
        Window size in pixels.
    s : int
        Stride in pixels.
    include_corner : bool, default False
        Also emit the bottom-right window ``(rows - Fi, cols - Fi)``.

    Returns
    -------
    RowsCols : list of tuple
        One entry per window. Units are mixed: interior windows are
        ``(row_idx, col_idx)`` grid indices, while edge-snapped windows carry
        the pixel offset for the snapped axis (e.g. ``(row_idx, cols - Fi)``).
    RowsCols_Slices : list of (slice, slice)
        Row and column slices for each window, in the same order.

    Notes
    -----
    With ``include_corner=False`` (the default, kept so the output matches the
    generator being tested) the bottom-right window is never produced. For a
    256x256 raster with Fi=64, s=32, pixels [224:256, 224:256] therefore lie
    in no window; pass ``include_corner=True`` for full coverage.
    NOTE(review): if the real slice generator uses the same logic, its chips
    skip this corner too — confirm.

    A raster smaller than ``Fi`` along either axis yields no windows.
    """
    n_rows, n_cols = shape[-2], shape[-1]
    if n_rows < Fi or n_cols < Fi:
        return [], []

    # Number of window positions per axis, counting the edge-snapped one
    nTimesRows = int((n_rows - Fi) // s + 1)
    nTimesCols = int((n_cols - Fi) // s + 1)

    # Interior windows: regular grid without the last row/column of positions
    RowsCols = [(row, col) for row in range(nTimesRows - 1) for col in range(nTimesCols - 1)]
    RowsCols_Slices = [(slice(row * s, row * s + Fi, 1), slice(col * s, col * s + Fi, 1))
                       for (row, col) in RowsCols]

    # Last column, snapped to the right edge
    col_rev = n_cols - Fi
    Rows4LastCol = [(row, col_rev) for row in range(nTimesRows - 1)]
    Rows4LastCol_Slices = [(slice(row * s, row * s + Fi, 1), slice(col_rev, col_rev + Fi, 1))
                           for (row, _) in Rows4LastCol]

    # Last row, snapped to the bottom edge
    row_rev = n_rows - Fi
    Cols4LastRow = [(row_rev, col) for col in range(nTimesCols - 1)]
    Cols4LastRow_Slices = [(slice(row_rev, row_rev + Fi, 1), slice(col * s, col * s + Fi, 1))
                           for (_, col) in Cols4LastRow]

    RowsCols = RowsCols + Rows4LastCol + Cols4LastRow
    RowsCols_Slices = RowsCols_Slices + Rows4LastCol_Slices + Cols4LastRow_Slices

    if include_corner:
        # Bottom-right window, absent from the original generator
        RowsCols.append((row_rev, col_rev))
        RowsCols_Slices.append((slice(row_rev, row_rev + Fi, 1), slice(col_rev, col_rev + Fi, 1)))

    return RowsCols, RowsCols_Slices

# Tiles for a 256x256 raster with 64x64 windows and stride 32, printed one per line
q, w = slicetester((3, 256, 256), 64, 32)
for tile_pos in q:
    print(tile_pos)

In [None]:
# test batchify

def batchitester(batch_size, RowsCols, RowsCols_Slices):
    """
    Split window positions and slices into ``ceil(N / batch_size)`` batches.

    ``np.array_split`` makes *balanced* batches whose sizes differ by at most
    one. For example, N=10 with batch_size=4 gives sizes 4, 3, 3 (not 4, 4, 2).
    No batch exceeds ``batch_size``.

    Parameters
    ----------
    batch_size : int
        Maximum number of windows per batch; must be positive.
    RowsCols, RowsCols_Slices : sequence
        Parallel lists, as returned by ``slicetester``.

    Returns
    -------
    BatchIndices, BatchRowsCols, BatchRowsCols_Slices : list of ndarray
        Per batch: the window indices, the positions and the
        ``(row_slice, col_slice)`` pairs. All three are empty lists when there
        are no windows; ``np.array_split`` rejects 0 sections, so empty input
        previously raised ``ValueError``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive or the two inputs differ in length.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(RowsCols) != len(RowsCols_Slices):
        raise ValueError("RowsCols and RowsCols_Slices must have the same length")
    if len(RowsCols) == 0:
        return [], [], []

    n = mceil(len(RowsCols) / batch_size)
    BatchIndices = np.array_split(np.arange(len(RowsCols)), n, axis=0)
    BatchRowsCols = np.array_split(RowsCols, n, axis=0)
    BatchRowsCols_Slices = np.array_split(RowsCols_Slices, n, axis=0)
    return BatchIndices, BatchRowsCols, BatchRowsCols_Slices

# Batch the tiles from above (batch size 2) and inspect the result.
batch_indices, batch_rowcols, batch_slices = batchitester(2, q, w)
print(batch_indices)
print(batch_rowcols)
print(batch_slices)

# Number of batches that would go to training with train_split = 0.9
print(int(0.9 * len(batch_rowcols)))

# Each batch is an array of (row_slice, col_slice) pairs. The old `for r, c in d`
# unpacked whole batches, so it only worked when every batch held exactly two tiles
# (it crashed on a shorter last batch). Iterate tile by tile instead.
for batch in batch_slices:
    for row_slice, col_slice in batch:
        print(row_slice, col_slice)

In [None]:
# Labels of the first three validation chips
(q1, w1), (q2, w2), (q3, w3) = (valid_dataset[k] for k in range(3))
for val_label in (w1, w2, w3):
    plotter(val_label)