In [None]:
import os
from pathlib import Path
from rocksdbutils import *
import xarray as xr

class AI4BNormal_S2(object):
    """
    Per-channel standardisation ``(x - mean) / std`` for channel-first arrays.

    The statistics are fixed at construction time and hold four entries, so the
    first axis of the input must have length 4 (or broadcast against it).
    """
    def __init__(self):
        # One mean / standard deviation per channel, stored as float32.
        self._mean_s2 = np.array(
            [5.4418573e+02, 7.6761194e+02, 7.1712860e+02, 2.8561428e+03],
            dtype=np.float32)
        self._std_s2 = np.array(
            [3.7141626e+02, 3.8981952e+02, 4.7989127e+02, 9.5173022e+02],
            dtype=np.float32)

    def __call__(self, img):
        """Return a normalised float32 copy of ``img``; the input is left untouched."""
        # astype() allocates a new array, so the in-place arithmetic below
        # never writes into the caller's data.
        normalised = img.astype(np.float32)
        # Reversing the axes puts the channel axis last, which lets the
        # 1-D statistics broadcast along it.
        channels_last = normalised.T
        np.subtract(channels_last, self._mean_s2, out=channels_last)
        np.divide(channels_last, self._std_s2, out=channels_last)
        # Reverse the axes again to hand back the channel-first layout.
        return channels_last.T


    
import albumentations as A
from albumentations.core.transforms_interface import  ImageOnlyTransform
class TrainingTransformS2(object):
    """
    Normalisation plus (in training mode) geometric augmentation of an
    (image, mask) pair.

    Built on Albumentations, this provides geometric transformation only.

    Parameters
    ----------
    prob : float
        Value passed as ``p`` to the outer ``A.Compose``.
    mode : {'train', 'valid'}
        'train' selects ``transform_train`` (normalise + augment),
        'valid' selects ``transform_valid`` (normalise only).
    norm : callable or None
        Applied to the image first; ``None`` disables normalisation.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'train' nor 'valid'.

    Usage: ``image_out, mask_out = transform(image, mask)``.
    """
    def __init__(self,  prob = 1., mode='train', norm = AI4BNormal_S2() ):
        self.geom_trans = A.Compose([
                    A.RandomCrop(width=128, height=128, p=1.0),  # Always apply random crop
                    A.OneOf([
                        A.HorizontalFlip(p=1),
                        A.VerticalFlip(p=1),
                        A.ElasticTransform(p=1), # VERY GOOD - gives perspective projection, really nice and useful - VERY SLOW   
                        A.GridDistortion(distort_limit=0.4,p=1.),
                        A.ShiftScaleRotate(shift_limit=0.25, scale_limit=(0.75,1.25), rotate_limit=180, p=1.0), # Most important Augmentation   
                        ],p=1.)
                    ],
            additional_targets={'imageS1': 'image','mask':'mask'},
            p = prob)
        if mode=='train':
            self.mytransform = self.transform_train
        elif mode =='valid':
            self.mytransform = self.transform_valid
        else:
            raise ValueError('transform mode can only be train or valid')

        self.norm = norm

    def transform_valid(self, data):
        """Normalise the image (if a normaliser is set) and cast the mask to float32."""
        timgS2, tmask = data
        if self.norm is not None:
            timgS2 = self.norm(timgS2)

        return timgS2,  tmask.astype(np.float32)

    def transform_train(self, data):
        """
        Normalise, then apply the geometric pipeline jointly to image and mask.

        Parameters
        ----------
        data : tuple
            ``(image, mask)``. ``image`` is either ``(C, H, W)`` or a time
            series ``(C, T, H, W)``; ``mask`` is 3-D with the spatial axes last.

        Returns
        -------
        tuple
            ``(image, mask)`` with the image in the same rank it came in and
            the mask as float32, both channel-first.

        Raises
        ------
        ValueError
            If the image is neither 3-D nor 4-D.
        """
        timgS2, tmask = data

        if self.norm is not None:
            timgS2 = self.norm(timgS2)

        tmask = tmask.astype(np.float32)

        # Special treatment of time series: the pipeline works on (H, W, C)
        # images, so a (C, T, H, W) series is folded to (C*T, H, W) first and
        # unfolded again afterwards. Plain (C, H, W) images pass straight through.
        series_dims = None
        if timgS2.ndim == 4:
            c2, t, h, w = timgS2.shape
            series_dims = (c2, t)
            timgS2 = timgS2.reshape(c2*t, h, w)
        elif timgS2.ndim != 3:
            raise ValueError(
                'transform_train expects an image of shape (C,H,W) or (C,T,H,W), '
                'got shape {}'.format(timgS2.shape))

        result = self.geom_trans(image=timgS2.transpose([1,2,0]),
                                 mask=tmask.transpose([1,2,0]))
        # Back from channel-last to channel-first.
        timgS2_t = result['image'].transpose([2,0,1])
        tmask_t = result['mask'].transpose([2,0,1])

        if series_dims is not None:
            # Spatial size is taken from the transformed array because the
            # pipeline (random crop) may have changed it.
            _, h2, w2 = timgS2_t.shape
            timgS2_t = timgS2_t.reshape(series_dims[0], series_dims[1], h2, w2)
        return timgS2_t,  tmask_t

    def __call__(self, *data):
        """Dispatch ``(image, mask)`` to the transform selected by ``mode``."""
        return self.mytransform(data)



def getFilelist(originpath, ftyp):
    """
    List the entries directly inside ``originpath`` whose extension is in ``ftyp``.

    Parameters
    ----------
    originpath : str or os.PathLike
        Directory to scan (not recursive).
    ftyp : str or iterable of str
        Accepted extension(s), with or without a leading dot, e.g. ``'.tif'``,
        ``'tif'`` or ``['tif', 'tiff']``. The text after the last dot of the
        entry name is compared exactly (case-sensitive).

    Returns
    -------
    list of str
        ``originpath + '/' + name`` for every matching entry, sorted by name.
    """
    originpath = os.fspath(originpath)

    # A bare string is one extension. Testing `ext in ftyp` against the string
    # itself would be a substring test and accept e.g. 'x.if' for '.tif'.
    if isinstance(ftyp, str):
        ftyp = [ftyp]
    wanted = {ext[1:] if ext.startswith('.') else ext for ext in ftyp}

    prefix = originpath if originpath.endswith('/') else originpath + '/'

    out = []
    # os.listdir order is arbitrary; sort for a reproducible result.
    for name in sorted(os.listdir(originpath)):
        _, dot, ext = name.rpartition('.')
        # Names without any dot have no extension and never match.
        if dot and ext in wanted:
            out.append(prefix + name)
    return out

img_size = 256

## create metadata file
# Shapes and dtypes declared for the stored inputs and labels.
metadata = {
    'inputs': {
        'inputs_shape': (6, img_size, img_size),  
        'inputs_dtype': np.float32      
    },
    'labels': {
        'labels_shape': (4, img_size, img_size),          
        'labels_dtype': np.float32
    }
}

## create list of images and labels
files = getFilelist('/home/repos/output/test_tiffs/', '.tif')
# Classify on the file name only: testing the full path would put every file
# into a list as soon as a parent directory name contained 'img' or 'label'.
imgs = sorted(f for f in files if 'img' in os.path.basename(f))
labs = sorted(f for f in files if 'label' in os.path.basename(f))

# Pairing is positional (i-th image with i-th label after sorting), so unequal
# counts mean at least one pair would be missing or mismatched.
# TODO: additionally verify that the paired file names correspond to each other.
if len(imgs) != len(labs):
    raise ValueError(
        'found {} image files but {} label files; cannot pair them'.format(
            len(imgs), len(labs)))

tif_paths = list(zip(imgs, labs))


## define function to load imgs and labs
def names2raster_function(names):
    """
    Load one (image, label) raster pair into memory.

    Parameters
    ----------
    names : tuple
        ``(image_path, label_path)``; each file is opened with
        ``xr.open_dataset`` and its ``'band_data'`` variable is read.

    Returns
    -------
    list
        ``[image, label]`` as arrays, with NaN entries of the label set to 0.
    """
    image_path, label_path = names
    # Load the image. The context manager closes the file once the values
    # have been read, instead of leaving the handle open.
    with xr.open_dataset(image_path) as ds:
        image = np.asarray(ds['band_data'].values)
    # Load the label (e.g. as integer classes).
    with xr.open_dataset(label_path) as ds:
        label = np.asarray(ds['band_data'].values)
    # NaN entries in the label are replaced by 0.
    label[np.isnan(label)] = 0
    return [image, label] # return [image[0,:,:], label[0]]

## create db
# Location handed to Rasters2RocksDB as `flname_prefix_save`. The directory is
# created up front; exist_ok=True makes re-running this cell harmless.
output_dir = '/home/repos/ssg2/Notebooks/testi_big.db'
os.makedirs(output_dir, exist_ok=True)

# Rasters2RocksDB is not defined in this notebook; presumably it comes from the
# `rocksdbutils` star import. The meaning of stride_divisor, train_split, Filter
# and split_type is defined by that class -- TODO confirm against its documentation.
# NOTE(review): TrainingTransformS2() defaults to mode='train', which random-crops
# to 128x128 and normalises with 4-channel statistics, while `metadata` declares
# (6, 256, 256) inputs -- confirm these are meant to agree.
rasters2rocks = Rasters2RocksDB(
    lstOfTuplesNames=tif_paths,            # (image_path, label_path) tuples built above
    names2raster_function=names2raster_function,  # loader returning [image, label] for one tuple
    metadata=metadata,                       # declared shapes / dtypes of inputs and labels
    flname_prefix_save=output_dir,           
    batch_size=4,
    transform=TrainingTransformS2(),
    stride_divisor=2,                           
    train_split=0.9,                         
    Filter=img_size,
    split_type='sequential'                  
)

# Run the conversion configured above.
rasters2rocks.create_dataset()



In [3]:
import os
from pathlib import Path
from rocksdbutils import *
import xarray as xr

# Open a previously written dataset. RocksDBDataset is not defined in this
# notebook; presumably it comes from the `rocksdbutils` star import -- TODO confirm.
db = RocksDBDataset('/home/output/rocks_db/ES.db/train.db')
# In the visible code this is only referenced by the commented-out export below.
output_dir2 = '/home/output/tiffs_back/'

# Quick look at the first 10 samples: print the shape of element 0 of each
# sample and one value from it.
# NOTE(review): assumes the dataset holds at least 10 samples and that element 0
# is a 4-D array with extent > 50 on its last two axes -- confirm.
for i in range(10):
    aa =  db[i][0]
    print(aa.shape)
    print(aa[1,1,50,50])

# def export_rock_to_tif(arr, src, path, name):
#     with rasterio.open(
#             path + name + '.tif',
#             'w',
#             crs=src.crs,
#             nodata=None, # change if data has nodata value
#             transform=src.transform,
#             driver='GTiff',
#             height=arr.shape[1],
#             width=arr.shape[2],
#             count=arr.shape[0],
#             dtype=arr.dtype
#         ) as dst:
#             for i in range(arr.shape[0]):
#                 dst.write(arr[i], i + 1)

# db
# for i in range(10):
#     export_rock_to_tif(db[i][0], rasterio.open('/home/repos/output/test_tiffs/test_img_0.tif'), output_dir2, 'img_' + str(i))

ValueError: cannot reshape array of size 393216 into shape (5,6,128,128)