In [None]:
!pip install --no-cache-dir 'GDAL[numpy]==3.6.2'

In [None]:
!pip install pillow matplotlib ipynb pysheds
!pip install numexpr>=2.7.3

In [None]:
import os
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from matplotlib import colors
from pysheds.grid import Grid
import rasterio
import cv2
import skimage

def run():
    """Build the training set: one (sketch, DEM) pair per GeoTIFF tile.

    Every ``*.tif`` file in ``options["input_path"]`` is processed in sorted
    order, so the sample order in the saved arrays is reproducible. A sketch is
    built for each tile with ``generate_sketch``, and the raw tile is read with
    ``cv2.imread(..., IMREAD_UNCHANGED)``. Both lists are then saved with
    ``np.savez`` to ``options["output_path"]`` as ``x`` (sketches) and ``y``
    (DEMs).

    Raises:
        FileNotFoundError: if no ``.tif`` tiles are found in the input directory.
        IOError: if OpenCV cannot read a tile.
        ValueError: if a sketch and its DEM differ in height/width. Saving them
            would silently misalign inputs and targets in the training data.
    """
    input_dir = pathlib.Path(options["input_path"])
    # glob() order depends on the filesystem; sort for a deterministic dataset.
    tile_file_paths = sorted(input_dir.glob('*.tif'))
    if not tile_file_paths:
        raise FileNotFoundError(f'No .tif tiles found in {str(input_dir)!r}')

    dems = []
    sketches = []
    total = len(tile_file_paths)
    for i, tile_path in enumerate(tile_file_paths, start=1):
        tile_file_path = str(tile_path)
        print(f'[{i}/{total}] Processing {tile_file_path} ...')

        sketch = generate_sketch(tile_file_path)
        dem = cv2.imread(tile_file_path, cv2.IMREAD_UNCHANGED)
        if dem is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError(f'Could not read raster {tile_file_path!r} with OpenCV')
        if sketch.shape[:2] != dem.shape[:2]:
            raise ValueError(
                f'Sketch shape {sketch.shape[:2]} does not match DEM shape '
                f'{dem.shape[:2]} for {tile_file_path!r}'
            )
        dems.append(dem)
        sketches.append(sketch)

    training_input = np.array(sketches)
    training_output = np.array(dems)
    np.savez(options['output_path'], x=training_input, y=training_output)

    print('Done!')

def generate_sketch(input_file_path):
    """Build the 3-channel sketch image for one DEM tile.

    The DEM is preprocessed (blurred) first. The three uint8 channels are then
    stacked with cv2.merge in this order: sea mask, river network, ridge network
    (labelled BGR). If ``options["plot_dem"]`` is set, the raw tile is displayed
    first.
    """
    if options["plot_dem"]:
        plt.imshow(plt.imread(input_file_path))
        plt.show()

    grid, dem = preprocess_dem(input_file_path)

    sea_channel = extract_sea(grid, dem)
    # 1 wherever the sea mask is unset (land), 0 at sea.
    land_mask = (~sea_channel.astype(bool)).astype(np.uint8)

    channels = [
        sea_channel,
        extract_rivers(grid, dem, land_mask),
        extract_ridges(grid, dem, land_mask),
    ]
    return cv2.merge(channels)  # BGR

def preprocess_dem(input_file_path):
    """Load a DEM tile, blur it, and re-open the result as a pysheds grid + raster.

    The tile is read twice. pysheds reads it only for its CRS, affine transform
    and nodata value; cv2 reads it for the pixel values. The pixels are smoothed
    with ``blur(..., options["dem_blurring_iterations"])``. They are then written
    to a temporary single-band int16 GeoTIFF that carries the original
    georeferencing, and read back so pysheds sees the blurred elevations.

    Returns:
        (grid, dem): the pysheds ``Grid`` built from the blurred file and the
        raster read from it.

    Raises:
        IOError: if OpenCV cannot read ``input_file_path``.
    """
    import tempfile

    original_grid = Grid.from_raster(input_file_path)
    raster = cv2.imread(input_file_path, cv2.IMREAD_UNCHANGED)
    if raster is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError(f'Could not read raster {input_file_path!r} with OpenCV')

    # NOTE(review): blur() also smooths any nodata fill values into neighbouring
    # cells; confirm this is acceptable for the tiles being used.
    raster = blur(raster, options["dem_blurring_iterations"])

    # Use a private temporary directory rather than a fixed /tmp path. This is
    # portable, and concurrent runs cannot overwrite each other's file. The
    # directory (and file) are removed once the data has been read back.
    with tempfile.TemporaryDirectory() as tmp_dir:
        blurred_path = os.path.join(tmp_dir, 'new.tif')
        # The context manager closes/flushes the dataset even if write() fails.
        with rasterio.open(
            blurred_path,
            'w',
            driver='GTiff',
            height=raster.shape[0],
            width=raster.shape[1],
            count=1,
            dtype='int16',  # NOTE(review): assumes int16 elevations -- confirm for non-SRTM inputs
            crs=original_grid.crs,
            transform=original_grid.affine,
            nodata=original_grid.nodata,
        ) as blurred_dataset:
            blurred_dataset.write(raster, 1)

        grid = Grid.from_raster(blurred_path)
        dem = grid.read_raster(blurred_path)

    return grid, dem

def extract_sea(grid, dem):
    """Return a uint8 mask that is 255 where ``dem <= 0`` and 0 elsewhere.

    ``grid`` is accepted for signature parity with the other extractors but is
    not used. If ``options["plot_sea"]`` is set, the mask is shown with a grey
    colormap.
    """
    # NOTE(review): any nodata cells with a value <= 0 are also classed as sea -- confirm intended.
    below_sea_level = dem <= 0
    sea = np.zeros_like(dem, dtype=np.uint8)
    sea[below_sea_level] = 255

    if options["plot_sea"]:
        grey_cmap = plt.colormaps.get_cmap("Greys").copy()
        grey_cmap.set_bad(color='black')
        plt.imshow(sea, interpolation='none', cmap=grey_cmap)
        plt.show()

    return sea

def extract_rivers(grid, dem, land_mask):
    """River channel: the thresholded flow-accumulation map of ``dem``, masked to land.

    If ``options["plot_rivers"]`` is set, the map is displayed before being returned.
    """
    river_map = extract_flow(grid, dem, land_mask)
    if not options["plot_rivers"]:
        return river_map

    plt.imshow(river_map)
    plt.show()
    return river_map


def extract_ridges(grid, dem, land_mask):
    """Ridge channel: flow accumulation on the *inverted* DEM, masked to land.

    Inverting the elevations (``max - dem``) turns ridges into valleys, so the
    drainage extraction used for rivers traces ridge lines instead.

    The inversion is done in a widened dtype. With an int16 DEM (as written by
    ``preprocess_dem``), ``dem.max() - dem`` silently wraps around whenever
    ``max - min > 32767``, e.g. if the tile contains large negative fill values.

    If ``options["plot_ridges"]`` is set, the map is displayed before being returned.
    """
    # At least int32 for integer input; float input keeps a float type.
    wide_dtype = np.promote_types(dem.dtype, np.int32)
    widened = dem.astype(wide_dtype)
    inverted = widened.max() - widened

    ridges = extract_flow(grid, inverted, land_mask)

    if options["plot_ridges"]:
        plt.imshow(ridges)
        plt.show()

    return ridges

def extract_flow(grid, dem, land_mask):
    """Binary map of high flow-accumulation cells, restricted to land.

    Steps:
    1. Condition the DEM (``condition_dem``).
    2. Compute pysheds flow directions and accumulation.
    3. Log-scale the accumulation and min-max normalise it to [0, 255].
    4. Keep cells above ``options["flow_threshold"]`` (set to 255, others 0).
    5. Zero out cells where ``land_mask`` is 0.

    Returns:
        uint8 array of 0/255 values with the DEM's shape.
    """
    conditioned_dem = condition_dem(grid, dem)

    # Direction codes passed to pysheds as `dirmap` (presumably N, NE, E, SE, S,
    # SW, W, NW -- pysheds' default dirmap order).
    direction_map = (64, 128, 1, 2, 4, 8, 16, 32)
    flow_direction = grid.flowdir(conditioned_dem, dirmap=direction_map)
    accumulation = grid.accumulation(flow_direction, dirmap=direction_map)

    # Accumulation spans orders of magnitude, so compare on a log scale.
    log_flow = np.log(accumulation + 1)
    lowest, highest = np.amin(log_flow), np.amax(log_flow)
    value_range = highest - lowest
    if value_range > 0:
        normalized = (log_flow - lowest) / value_range
    else:
        # Constant accumulation (e.g. a perfectly flat tile). Avoid 0/0 -> NaN,
        # whose cast to uint8 is undefined.
        normalized = np.zeros_like(log_flow, dtype=np.float64)
    flow = np.array(normalized * 255, dtype=np.uint8)  # Normalize to [0,255]

    _, flow = cv2.threshold(flow, options["flow_threshold"], 255, cv2.THRESH_BINARY)
    flow *= land_mask

    # Optional post-processing, currently disabled:
    # flow = skimage.morphology.skeletonize(flow).astype(np.uint8) * 255
    # flow = smoothen(flow, 7, 16, -4)

    return flow

def smoothen(data, blur, sharp1, sharp2):
    """Unsharp-mask style sharpening of ``data``.

    The detail layer is ``cv2.subtract(data, GaussianBlur(data))``, using a
    square ``blur`` x ``blur`` kernel with sigma derived from the kernel size;
    cv2.subtract saturates for integer images. The result is
    ``cv2.addWeighted(data, sharp1, detail, sharp2, 0)``.
    """
    kernel_size = (blur, blur)
    detail = cv2.subtract(data, cv2.GaussianBlur(data, kernel_size, 0))
    return cv2.addWeighted(data, sharp1, detail, sharp2, 0)

def condition_dem(grid, dem):
    """Condition a DEM with pysheds before flow routing.

    Applies ``grid.fill_pits``, then ``grid.fill_depressions``, then
    ``grid.resolve_flats``. Each stage consumes the previous stage's output, and
    the last stage's result is returned.
    """
    conditioned = dem
    for stage in (grid.fill_pits, grid.fill_depressions, grid.resolve_flats):
        conditioned = stage(conditioned)
    return conditioned

def blur(data, iterations):
    """Low-pass filter ``data`` by going down and back up a Gaussian pyramid.

    The image is reduced ``iterations`` times with ``downsample`` (pyrDown), then
    expanded the same number of times with ``upsample`` (pyrUp).

    The output always has the same height and width as the input.
    ``downsample`` floor-divides each side. Any side that is not a multiple of
    ``2**iterations`` would otherwise shrink (e.g. 3601 -> 3584 for 5
    iterations), so the blurred DEM would no longer line up with the original
    tile. To prevent this, the input is first edge-padded at the bottom/right to
    a multiple of ``2**iterations``, and the result is cropped back.

    Args:
        data: 2-D (or H x W x C) array accepted by cv2.pyrDown / cv2.pyrUp.
        iterations: number of pyramid levels; values <= 0 return ``data`` unchanged.

    Returns:
        The blurred array, with ``data.shape[:2]`` preserved.
    """
    if iterations <= 0:
        return data

    height, width = data.shape[:2]
    factor = 2 ** iterations
    # Smallest non-negative padding that makes each side a multiple of `factor`.
    pad_rows = -height % factor
    pad_cols = -width % factor
    pad_width = [(0, pad_rows), (0, pad_cols)] + [(0, 0)] * (data.ndim - 2)
    padded = np.pad(data, pad_width, mode='edge')

    for _ in range(iterations):
        padded = downsample(padded)
    for _ in range(iterations):
        padded = upsample(padded)

    # Crop the padding away; copy so downstream consumers get a C-contiguous array.
    return np.ascontiguousarray(padded[:height, :width])

def upsample(data):
    """Double both spatial dimensions of ``data`` with cv2.pyrUp."""
    rows, cols = data.shape[0], data.shape[1]
    target_size = (cols * 2, rows * 2)  # OpenCV sizes are (width, height)
    return cv2.pyrUp(data, dstsize=target_size)

def downsample(data):
    """Halve both spatial dimensions of ``data`` (rounding down) with cv2.pyrDown."""
    rows, cols = data.shape[0], data.shape[1]
    target_size = (cols // 2, rows // 2)  # OpenCV sizes are (width, height)
    return cv2.pyrDown(data, dstsize=target_size)

# Pipeline configuration. The functions above look `options` up when they run,
# so defining it after their definitions is fine.
options = dict(
    # I/O
    input_path="earthdata_tiles",      # directory scanned for *.tif tiles by run()
    output_path="training_data.npz",   # np.savez target: x = sketches, y = DEMs
    recreate=True,                     # NOTE(review): not read by any code in this notebook section
    # Processing
    dem_blurring_iterations=5,         # pyramid levels used by blur()
    flow_threshold=235,                # cv2.threshold cut-off on the 0-255 log-accumulation map
    # Debug plots
    plot_dem=False,
    plot_sea=False,
    plot_rivers=False,
    plot_ridges=False,
    plot_sketch=False,                 # NOTE(review): not read by any code in this notebook section
    save_sketch=True,                  # NOTE(review): not read by any code in this notebook section
)

run()


In [None]:
# Visualize training data: each figure shows a sketch (left) next to its DEM (right).

# NpzFile keeps the archive open; the context manager closes it once the arrays are loaded.
with np.load('training_data.npz') as training_data:
    sketch_images = training_data['x']
    dem_images = training_data['y']

for index, (sketch_image, dem_image) in enumerate(zip(sketch_images, dem_images)):
    fig, (sketch_ax, dem_ax) = plt.subplots(1, 2, figsize=(20, 10))
    # Sketches were assembled with cv2.merge in BGR order (sea, rivers, ridges);
    # matplotlib expects RGB, so reverse the channel axis before display.
    sketch_ax.imshow(sketch_image[..., ::-1])
    sketch_ax.set_title(f'Sketch {index}')
    dem_ax.imshow(dem_image, cmap='viridis')
    dem_ax.set_title(f'DEM {index}')
    plt.show()
    # Release each figure so looping over many tiles does not accumulate memory.
    plt.close(fig)
