In [None]:
import rasterio
import rasterio.features
import rasterio.warp
import geopyspark as gps
import numpy as np
import matplotlib.pyplot as plt

from pyspark import SparkContext
from geonotebook.wrappers import TMSRasterData
from osgeo import osr

import os
import math
import boto3

%matplotlib inline

In [None]:
# Spark configuration for a local SRTM ingest run (all local cores, Spark UI on,
# dynamic allocation and the YARN timeline service switched off).
conf = (
    gps.geopyspark_conf("local[*]", "SRTM Ingest")
    .set("spark.dynamicAllocation.enabled", False)
    .set("spark.ui.enabled", True)
    .set("spark.hadoop.yarn.timeline-service.enabled", False)
)

In [None]:
# getOrCreate() makes this cell safe to re-run: constructing a second
# SparkContext while one is active raises an error. If a context already
# exists it is returned as-is and `conf` is not re-applied to it.
sc = SparkContext.getOrCreate(conf=conf)

In [None]:
# SRTM tiles to ingest; names follow the pattern N00E<3-digit longitude>.hgt.
SRTM_LONGITUDES_E = [6, 9, 10, 11, 12, 13, 14, 15, 16, 17]
file_names = ['N00E{:03d}.hgt'.format(lon) for lon in SRTM_LONGITUDES_E]
# file_names = file_names[0:2]
print(len(file_names))
print(file_names[:10])

In [None]:
def get_metadata(uri, window_size=512):
    """Split the raster at ``uri`` into square pixel windows and describe each one.

    Only metadata (bounds, size, CRS) is read here; no pixel data.

    Args:
        uri: Anything ``rasterio.open`` accepts (in this notebook, a ``/vsicurl/`` URL).
        window_size: Edge length of each window in pixels (default 512).

    Returns:
        A list of dicts, one per window, each with keys:

        * ``'uri'``: the input ``uri``;
        * ``'window'``: ``((row_start, row_stop), (col_start, col_stop))`` -- the
          tuple-window order that rasterio's ``read(window=...)`` expects;
        * ``'projected_extent'``: a ``gps.ProjectedExtent`` covering exactly the
          window's pixels, in the raster's own CRS (as a proj4 string).

        An empty list if the raster's metadata cannot be read. This is
        best-effort so that one missing/unreadable tile does not fail the job;
        a warning is emitted instead.
    """
    import warnings

    if "GDAL_DATA" not in os.environ:
        # Environment-specific fallback so GDAL can locate its support files.
        os.environ["GDAL_DATA"] = "/home/hadoop/.local/lib/python3.4/site-packages/fiona/gdal_data"

    try:
        with rasterio.open(uri) as dataset:
            bounds = dataset.bounds
            height = dataset.height
            width = dataset.width
            # Use the `crs` property rather than the deprecated `get_crs()`:
            # an AttributeError from the latter used to be swallowed by the
            # bare `except`, silently dropping every file.
            srs = osr.SpatialReference()
            srs.ImportFromWkt(dataset.crs.wkt)
            proj4 = srs.ExportToProj4()
    except Exception as exc:
        # Best-effort: skip unreadable rasters, but say so rather than failing silently.
        warnings.warn("get_metadata: skipping {}: {!r}".format(uri, exc))
        return []

    x_span = bounds.right - bounds.left
    y_span = bounds.bottom - bounds.top

    records = []
    # Rows iterate over height and columns over width (the original swapped
    # them, which only worked for square rasters).
    for row_start in range(0, height, window_size):
        # Stops are clamped to height-1 / width-1, so the last row and column
        # are never read. NOTE(review): presumably intentional for SRTM .hgt
        # tiles, whose edge row/column duplicates the neighbouring tile's --
        # confirm; for other rasters this drops real data.
        row_stop = min(height - 1, row_start + window_size)
        for col_start in range(0, width, window_size):
            col_stop = min(width - 1, col_start + window_size)
            if row_stop <= row_start or col_stop <= col_start:
                # Zero-size window (e.g. a dimension of k*window_size + 1): nothing to read.
                continue

            # Map pixel edges linearly onto the raster's bounding box.
            left = bounds.left + x_span * (float(col_start) / width)
            right = bounds.left + x_span * (float(col_stop) / width)
            top = bounds.top + y_span * (float(row_start) / height)
            bottom = bounds.top + y_span * (float(row_stop) / height)
            extent = gps.Extent(left, bottom, right, top)

            records.append({
                'uri': uri,
                'window': ((row_start, row_stop), (col_start, col_stop)),
                'projected_extent': gps.ProjectedExtent(extent=extent, proj4=proj4),
            })
    return records


In [None]:
def get_data(line):
    """Attach the pixel data for one window record.

    Args:
        line: A dict with at least ``'uri'`` and ``'window'`` keys; any other
            keys (e.g. ``'projected_extent'``) are passed through unchanged.

    Returns:
        A new dict without ``'uri'`` and ``'window'``, with ``'data'`` set to
        band 1 of the raster at ``uri`` read over ``window``. ``line`` itself
        is not modified.
    """
    with rasterio.open(line['uri']) as src:
        band = src.read(1, window=line['window'])

    dropped = ('uri', 'window')
    record = {key: value for key, value in line.items() if key not in dropped}
    record['data'] = band
    return record

In [None]:
def filename_to_data(filename, base_uri="/vsicurl/https://s3.amazonaws.com/mrgeo-source/srtm-v3-30/"):
    """Read every window of one raster tile into memory.

    Args:
        filename: Tile file name, e.g. ``'N00E006.hgt'``, appended to ``base_uri``.
        base_uri: Prefix the file name is appended to. Defaults to the S3
            bucket used so far, read through GDAL's ``/vsicurl/`` handler.

    Returns:
        A list of window records as returned by ``get_data`` (keys
        ``'projected_extent'`` and ``'data'``); empty if ``get_metadata``
        returned no windows for the file.
    """
    full_filename = "{}{}".format(base_uri, filename)
    return [get_data(line) for line in get_metadata(full_filename)]

In [None]:
# Fan each file name out into its per-window records (each carrying its pixel array).
rdd0 = sc.parallelize(file_names)
# cache(): count() below forces every raster to be read from S3; without
# caching, the groupBy in the next cell would download all of it again.
rdd1 = rdd0.flatMap(filename_to_data).cache()
print(rdd1.count())

In [None]:
# Group window records by their projected extent, giving (extent, records) pairs.
# NOTE(review): every window appears to have a distinct extent, so presumably
# each group holds a single record and this shuffle could be a plain map --
# confirm (the original author flagged this line with XXX).
rdd2 = rdd1.groupBy(lambda record: record['projected_extent'])

In [None]:
def make_tiles(line, no_data_value=0):
    """Turn one grouped ``(projected_extent, records)`` pair into a tile pair.

    Args:
        line: A ``(projected_extent, records)`` pair as produced by the
            ``groupBy`` on ``'projected_extent'``; each record is a dict with
            a 2-D ``'data'`` array.
        no_data_value: Cell value marked as no-data in the tile (default 0,
            unchanged from before). NOTE(review): 0 also matches sea-level
            pixels; SRTM .hgt voids are conventionally -32768 -- confirm which
            is intended.

    Returns:
        ``(projected_extent, gps.Tile)`` where the tile's array stacks the
        records' ``'data'`` arrays along a new leading axis, in group order.
    """
    projected_extent, records = line
    # np.array over a list of 2-D arrays adds a leading axis (one entry per record).
    bands = np.array([record['data'] for record in records])
    tile = gps.Tile.from_numpy_array(bands, no_data_value=no_data_value)
    return (projected_extent, tile)


In [None]:
# Convert each (extent, records) group into a (ProjectedExtent, Tile) pair.
rdd3 = rdd2.map(make_tiles)

In [None]:
# Wrap the (ProjectedExtent, Tile) pairs as a spatial (non-temporal) raster layer.
raster_layer = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL, rdd3)

In [None]:
# Reproject to EPSG:3857 (Web Mercator) and cut into a global layout; a global
# layout assigns zoom levels, presumably needed by the pyramid() call below.
tiled_raster_layer = raster_layer.tile_to_layout(layout = gps.GlobalLayout(), target_crs=3857)

In [None]:
# Build the multi-zoom pyramid; its `levels` are written to the catalog below.
pyramid = tiled_raster_layer.pyramid()

In [None]:
# Persist every zoom level of the pyramid into a local file-system catalog.
# NOTE(review): the commented-out read-back cell below queries
# "srtm-geopyspark", not this layer name.
SRTM_CATALOG_URI = "file:///tmp/dg-srtm/"
SRTM_LAYER_NAME = "srtm-geopyspark-1"

for layer in pyramid.levels.values():
    gps.write(SRTM_CATALOG_URI, SRTM_LAYER_NAME, layer)

In [None]:
# pyramid2 = gps.Pyramid([gps.query("file:///tmp/dg-srtm", "srtm-geopyspark", layer_zoom=n, num_partitions=1024*16) for n in range(0,13+1)])

In [None]:
# Derive color breaks from the pyramid's value histogram, mapped onto the 'viridis' palette.
histogram = pyramid.get_histogram()
color_map = gps.ColorMap.build(breaks=histogram, colors='viridis')

In [None]:
# Tile server that renders the pyramid using the color map built above.
tms = gps.TMS.build(source=pyramid, display=color_map)

In [None]:
# Show the TMS layer on the notebook map.
# NOTE(review): `M` is not defined anywhere in this notebook; presumably it is
# the map handle provided by the GeoNotebook kernel -- confirm. The
# "Weighted Layer" label does not describe SRTM elevation; consider renaming.
M.add_layer(TMSRasterData(tms), name="Weighted Layer")