In [None]:
import json
import random
import numpy as np
import rasterio as rio
from pathlib import Path
from functools import partial

In [None]:
# --- Raw inputs -------------------------------------------------------------
# Source imagery, colour-coded label rasters and DSM rasters live in
# sub-folders of the raw data directory.
data_dir = Path("RawData")
img_dir, lbl_dir, dsm_dir = (data_dir / sub for sub in ("RGBIR", "LABELS", "NDSM"))

# --- Generated outputs ------------------------------------------------------
# Every split gets its own folder with an image sub-folder and a label
# sub-folder.
aug_dir = Path("Data")
train_dir, valid_dir, test_dir = (
    aug_dir / split for split in ("Train_Data", "Validation_Data", "Test_Data")
)
train_imgd, train_lbld = train_dir / "Images", train_dir / "Labels"
valid_imgd, valid_lbld = valid_dir / "Images", valid_dir / "Labels"
test_imgd, test_lbld = test_dir / "Images", test_dir / "Predicted_Labels"

# --- Split manifests --------------------------------------------------------
# JSON files that map each tile ID to its image/label paths.
train_map = Path("Configs", "Train_Map.json")
valid_map = Path("Configs", "Validation_Map.json")
test_map = Path("Configs", "Test_Map.json")

In [None]:
# Fraction of the labelled tiles assigned to the training split; the remaining
# labelled tiles form the validation split.
train_ratio = 0.8
# When True, only tiles that also have a DSM raster are used, and the DSM bands
# are appended to the image bands of every output tile.
use_dsm = True
# Expected (height, width) of every image and label raster; tiles with any
# other size fail an assertion during processing.
img_dim = (6000, 6000)
# channels = (1, 2, 3, 4)

In [None]:
# Label colour (R, G, B) -> integer class id.
_CLASS_BY_COLOR = {
    (255, 255, 255): 1,
    (0, 0, 255): 2,
    (0, 255, 255): 3,
    (0, 255, 0): 4,
    (255, 255, 0): 5,
    (255, 0, 0): 6,
}


def get_class(rgb_vec):
    """Translate one label colour into its integer class id.

    Parameters
    ----------
    rgb_vec : numpy.ndarray
        Integer array of shape (3,) holding a single (R, G, B) colour.

    Returns
    -------
    int
        Class id in the range 1..6.

    Raises
    ------
    AssertionError
        If `rgb_vec` is not an integer ndarray of shape (3,), or if the colour
        is not one of the six known label colours.
    """
    # `and` short-circuits, so `.shape` / `.dtype` are only touched on ndarrays.
    is_int_triplet = (
        isinstance(rgb_vec, np.ndarray)
        and rgb_vec.shape == (3,)
        and np.issubdtype(rgb_vec.dtype, np.integer)
    )
    assert is_int_triplet, "<rgb_vec> must be of type `numpy.ndarray` of shape (3,) with `dtype` of `numpy.integer`!"

    class_id = _CLASS_BY_COLOR.get(tuple(rgb_vec.tolist()))
    if class_id is None:
        raise AssertionError("Invalid Color Code")
    return class_id

In [None]:
def _index_tiles(directory, pattern):
    """Map tile IDs to file paths for every file under `directory` matching `pattern`.

    The tile ID is the file stem with its first underscore-separated token
    removed.

    Parameters
    ----------
    directory : pathlib.Path
        Folder that is searched with `directory.glob(pattern)`.
    pattern : str
        Glob pattern, e.g. '**/*.tif'.

    Returns
    -------
    dict
        Tile ID (str) -> path of the matching file.

    Raises
    ------
    ValueError
        If two files yield the same tile ID.
    """
    tiles = dict()
    for path in directory.glob(pattern):
        tile_id = '_'.join(path.stem.split('_')[1:])
        if tile_id in tiles:
            raise ValueError("Tile IDs are not Unique! <{}>".format(tile_id))
        tiles[tile_id] = path
    return tiles


img_dict = _index_tiles(img_dir, '**/*.tif')
lbl_dict = _index_tiles(lbl_dir, '**/*.tif')
dsm_dict = _index_tiles(dsm_dir, '**/*.jpg')

ikeys = set(img_dict)
lkeys = set(lbl_dict)
dkeys = set(dsm_dict)

# Labelled tiles are split into train/validation; usable tiles without a label
# become the test set. With `use_dsm`, a tile is only usable if it has a DSM.
if use_dsm:
    imd_keys = ikeys & dkeys
    common_keys = lkeys & imd_keys
    test_keys = imd_keys - common_keys
else:
    common_keys = lkeys & ikeys
    test_keys = ikeys - common_keys

# Set iteration order for str keys varies between interpreter runs, so sort
# before shuffling; together with the fixed seed this makes the split
# reproducible. Set `split_seed = None` to get a different split on every run.
test_keys = sorted(test_keys)
common_keys = sorted(common_keys)
split_seed = 42
# A dedicated generator keeps the global `random` state untouched.
random.Random(split_seed).shuffle(common_keys)

split_index = int(np.floor(len(common_keys) * train_ratio))
train_keys = common_keys[:split_index]
valid_keys = common_keys[split_index:]

In [None]:
# Augmentation table, currently disabled: (transform, file-name suffix) pairs.
# augs = (
#     ((lambda m : m), 'R00'),
#     (partial(np.rot90, k=1, axes=(1, 2)), 'R01'),
#     (partial(np.rot90, k=2, axes=(1, 2)), 'R02'),
#     (partial(np.rot90, k=3, axes=(1, 2)), 'R03'),
#     (partial(np.flip, axis=1), 'FUD'),
#     (partial(np.flip, axis=2), 'FLR')
# )


def _pad_to_shape(bands, target_hw):
    """Edge-pad a (count, height, width) array at the bottom/right up to `target_hw`.

    Raises AssertionError if the array is larger than `target_hw` in either
    spatial dimension.
    """
    deficit = [target - actual for target, actual in zip(target_hw, bands.shape[1:])]
    assert all(d >= 0 for d in deficit)
    # No padding on the band axis; pad only after the last row / column.
    pad_width = [(0, 0)] + [(0, d) for d in deficit]
    return np.pad(array=bands, pad_width=pad_width, mode='edge')


def _rgb_to_class_raster(rgb):
    """Convert a (3, height, width) colour raster to a (1, height, width) uint8 class raster.

    `get_class` is called once per distinct colour instead of once per pixel,
    so it remains the single definition of the colour -> class mapping.
    Raises AssertionError (from `get_class`) on an unknown colour or a
    non-integer / non-3-band input.
    """
    pixels = rgb.reshape(rgb.shape[0], -1)
    colors, inverse = np.unique(pixels, axis=1, return_inverse=True)
    class_ids = np.array(
        [get_class(colors[:, i]) for i in range(colors.shape[1])],
        dtype=np.uint8,
    )
    # reshape(-1): NumPy 2.0.0 returned a non-flat inverse when `axis` is given.
    return class_ids[inverse.reshape(-1)].reshape((1,) + rgb.shape[1:])


def _write_raster(path, array, meta):
    """Write a (count, height, width) array to `path`, updating `meta` in place to match it."""
    meta['count'], meta['height'], meta['width'] = array.shape
    meta['dtype'] = array.dtype
    with rio.open(path, 'w', **meta) as dst:
        dst.write(array)


# One entry per split: (tile IDs, image output dir, label output dir or None,
# JSON manifest path).
dconfig = (
    (train_keys, train_imgd, train_lbld, train_map),
    (valid_keys, valid_imgd, valid_lbld, valid_map),
    (test_keys, test_imgd, None, test_map)
)

for (ukeys, idir, ldir, map_f) in dconfig:
    # Create the output locations up front; neither rasterio nor `open`
    # creates missing parent directories.
    idir.mkdir(parents=True, exist_ok=True)
    if ldir is not None:
        ldir.mkdir(parents=True, exist_ok=True)
    map_f.parent.mkdir(parents=True, exist_ok=True)

    data_map = dict()
    for k in ukeys:
        ipath = img_dict[k]
        lpath = lbl_dict[k] if ldir is not None else None

        # Image: read all bands, optionally append the DSM bands, write out.
        with rio.open(ipath, 'r') as isrc:
            imeta = isrc.meta.copy()
#             imeta.pop('crs', None)
#             imeta.pop('transform', None)
            iarray = isrc.read()
        assert iarray.shape[1:] == img_dim
        if use_dsm:
            dpath = dsm_dict[k]
            with rio.open(dpath, 'r') as dfp:
                dsm = dfp.read()
            if dsm.shape[1:] != img_dim:
                dsm = _pad_to_shape(dsm, img_dim)
            iarray = np.concatenate((iarray, dsm), axis=0)
        impath = idir / ipath.name
        _write_raster(impath, iarray, imeta)

        # Label (train/validation only): colour raster -> single-band class ids.
        lbpath = None
        if lpath is not None:
            with rio.open(lpath, 'r') as lsrc:
                lmeta = lsrc.meta.copy()
#                 lmeta.pop('crs', None)
#                 lmeta.pop('transform', None)
                larray = lsrc.read()
            assert larray.shape[1:] == img_dim
            larray = _rgb_to_class_raster(larray)
            lbpath = ldir / lpath.name
            _write_raster(lbpath, larray, lmeta)

        data_map[k] = {
            'IMAGE': str(impath),
            'LABEL': str(lbpath) if lbpath is not None else None
        }

    with open(map_f, 'w') as mfp:
        json.dump(data_map, mfp, indent=4)