In [None]:
!pip install --upgrade --no-cache-dir setuptools==57.5.0
!pip install python-dotenv gdal==3.6.4

In [None]:
!pip install pillow matplotlib ipynb pysheds
!pip install "numexpr>=2.7.3"

In [None]:
import concurrent
import concurrent.futures
import os
import pathlib
from importlib import reload

import cv2
import matplotlib.pyplot as plt
import numpy as np

# reload(lib.generate_training_data)
from lib.generate_training_data import generate_sketch

def run(options, async_exec):
    """Build (sketch, DEM) training pairs from GeoTIFF tiles and save them.

    Every ``*.tif`` file directly inside ``options["input_path"]`` is passed to
    ``generate_sketch`` together with ``options``. The returned sketches and
    DEMs are stacked into two arrays and written to ``options["output_path"]``
    with ``np.savez`` under the keys ``x`` (sketches) and ``y`` (DEMs).

    Tiles are processed in sorted path order and results are stored in that
    same order, so the saved arrays are reproducible regardless of filesystem
    listing order or of which worker finishes first.

    Args:
        options: dict of settings. ``input_path`` and ``output_path`` are read
            here; the whole dict is forwarded unchanged to ``generate_sketch``.
        async_exec: if truthy, tiles are processed in parallel in a
            ``ProcessPoolExecutor``; otherwise sequentially in this process.

    Raises:
        FileNotFoundError: if no ``*.tif`` tile is found in ``input_path``
            (nothing is written in that case).
    """
    # Sort explicitly: glob() yields entries in filesystem-dependent order.
    tile_file_paths = sorted(
        str(path) for path in pathlib.Path(options["input_path"]).glob('*.tif')
    )
    if not tile_file_paths:
        # Fail loudly instead of silently writing an empty training archive.
        raise FileNotFoundError(
            f"No '*.tif' tiles found in {options['input_path']!r}"
        )

    # One slot per tile so results keep input order even when they are
    # completed out of order by the process pool.
    results = [None] * len(tile_file_paths)

    if async_exec:
        with concurrent.futures.ProcessPoolExecutor() as process_pool:
            future_to_index = {
                process_pool.submit(generate_sketch, tile_file_path, options): index
                for index, tile_file_path in enumerate(tile_file_paths)
            }

            print(f'Processing {len(future_to_index)} images...')
            for future in concurrent.futures.as_completed(future_to_index):
                # result() re-raises any exception raised inside the worker.
                sketch_id, dem, sketch = future.result()
                results[future_to_index[future]] = (dem, sketch)
                print(f'Processed {sketch_id}')
    else:
        total = len(tile_file_paths)
        for index, tile_file_path in enumerate(tile_file_paths):
            progress = f'[{index + 1}/{total}]'
            print(f'{progress} Processing {os.path.basename(tile_file_path)} ...')

            sketch_id, dem, sketch = generate_sketch(tile_file_path, options)
            results[index] = (dem, sketch)
            print(f'{progress} Processed {sketch_id}')

    training_input = np.array([sketch for _, sketch in results])
    training_output = np.array([dem for dem, _ in results])

    # np.savez does not create missing parent directories.
    output_dir = os.path.dirname(options['output_path'])
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.savez(options['output_path'], x=training_input, y=training_output)

    print('Done!')

if __name__ == "__main__":
    # Generate the training archive using the parallel (process pool) path.
    # run() itself only reads input_path and output_path; the remaining keys
    # are forwarded as part of the options dict to generate_sketch.
    run(
        options=dict(
            input_path="data/earthdata_tiles",
            output_path="data/training_data.npz",
            flow_threshold=230,
            dem_target_size=3600,
            sketch_target_size=512,
        ),
        async_exec=True,
    )


In [None]:
# Visualize training data
import matplotlib.pyplot as plt
import numpy as np

# Load the archive from the same location the generation cell writes to
# ("output_path": "data/training_data.npz"); the bare file name only resolved
# if the file had been moved by hand.
training_data = np.load('data/training_data.npz')
sketches = training_data['x']
dems = training_data['y']

# One figure per training pair: input sketch on the left, target DEM on the right.
for index, (sketch, dem) in enumerate(zip(sketches, dems)):
    fig, (ax_sketch, ax_dem) = plt.subplots(1, 2, figsize=(10, 20))
    ax_sketch.imshow(sketch)
    ax_sketch.set_title(f'Sketch #{index}')
    ax_dem.imshow(dem, cmap='viridis')
    ax_dem.set_title(f'DEM #{index}')
    plt.show()
    # Release the figure so a large dataset does not accumulate open figures.
    plt.close(fig)
