In [None]:
import numpy as np
from math import ceil
import rasterio as rio
from pathlib import Path
from itertools import product
from tqdm.notebook import tqdm
from rasterio import windows as windows

In [None]:
# Input single-band label rasters and output RGB renderings share one root folder.
DATA_ROOT = Path("Data/Validation_Data")
src_dir = DATA_ROOT / "Labels"
dst_dir = DATA_ROOT / "RGB_Labels"

In [None]:
# Class index -> RGB colour used when rendering label rasters.
# The conversion loop subtracts 1 from each raw label before looking it up here,
# so key k is the colour for raw label value k + 1.
# NOTE(review): the palette looks like the ISPRS 2D semantic-labelling scheme
# (impervious/building/low-veg/tree/car/clutter) — confirm against the dataset docs.
_palette = [
    (255, 255, 255),  # 0: white
    (0, 0, 255),      # 1: blue
    (0, 255, 255),    # 2: cyan
    (0, 255, 0),      # 3: green
    (255, 255, 0),    # 4: yellow
    (255, 0, 0),      # 5: red
]
color_map = {cls: np.array(rgb) for cls, rgb in enumerate(_palette)}

In [None]:
def generate_windows(img_width, img_height, width, height):
    """Tile an ``img_width`` x ``img_height`` image into ``width`` x ``height`` windows.

    Tiles start at every multiple of ``width`` along the columns and every multiple
    of ``height`` along the rows. They are ordered column by column: the outer loop
    runs over column offsets and the inner loop over row offsets. Each tile is
    intersected with the full image extent, so tiles on the right and bottom edges
    are clipped and may be smaller than ``width`` x ``height``.

    Args:
        img_width: Image width in pixels.
        img_height: Image height in pixels.
        width: Tile width in pixels.
        height: Tile height in pixels.

    Returns:
        list of ``rasterio.windows.Window`` covering the image.
    """
    full_extent = windows.Window(col_off=0, row_off=0, width=img_width, height=img_height)
    col_starts = range(0, img_width, width)
    row_starts = range(0, img_height, height)
    return [
        windows.Window(col_off=col, row_off=row, width=width, height=height).intersection(full_extent)
        for col in col_starts
        for row in row_starts
    ]

In [None]:
# Render every single-band label GeoTIFF in `src_dir` as a 3-band uint8 RGB GeoTIFF
# in `dst_dir`, processing the image in 600x600 windows.

# rasterio does not create missing parent directories, so make sure the output
# directory exists before any file is opened for writing.
dst_dir.mkdir(parents=True, exist_ok=True)

# The class -> RGB lookup table does not depend on the image, so build it once
# instead of once per window. Row k holds color_map[k].
known_classes = np.array(sorted(color_map))
lookup = np.zeros((known_classes.max() + 1, 3), dtype=np.uint8)
for cls, rgb in color_map.items():
    lookup[cls] = rgb

for imf in tqdm(list(src_dir.glob('*.tif'))):
    with rio.open(imf, 'r') as src:
        assert src.count == 1, f"{imf.name}: expected a single-band label raster, got {src.count} bands"
        meta = src.meta.copy()
        ih, iw = meta['height'], meta['width']
        wins = generate_windows(img_width=iw, img_height=ih, width=600, height=600)
        meta['count'] = 3
        meta['dtype'] = np.uint8
        # NOTE(review): `nodata` is inherited unchanged from the label raster; confirm it
        # is a valid uint8 value, since rasterio may refuse an out-of-range nodata.
        dst_path = dst_dir / (imf.stem + '_RGB' + imf.suffix)
        with rio.open(dst_path, 'w', **meta) as dst:
            for w in tqdm(wins, leave=False, desc=imf.name):
                # Widen before shifting so a raw label of 0 becomes -1 (and is reported
                # below) instead of wrapping around in an unsigned dtype.
                # Assumes integer-valued labels — TODO confirm for float rasters.
                img_array = src.read(indexes=1, window=w).astype(np.int64) - 1
                unknown = np.setdiff1d(np.unique(img_array), known_classes)
                if unknown.size:
                    raise ValueError(
                        f"{imf.name}: raw label values {(unknown + 1).tolist()} in window {w} "
                        f"have no entry in color_map"
                    )
                # Fancy-index the (rows, cols) class array into (rows, cols, 3), then put
                # the channel axis first, the band-major layout dst.write expects.
                rgb_img = np.moveaxis(lookup[img_array], -1, 0)
                dst.write(rgb_img, window=w)