In [None]:
import geopyspark as gps
from pyspark import SparkContext
from shapely.geometry import mapping, shape, asShape, MultiPoint, MultiLineString
from geonotebook.wrappers import TMSRasterData, GeoJsonData
import pyproj
from shapely.ops import transform
from functools import partial
import os, urllib.request, json

In [None]:
# Set up our spark context
# Local Spark configuration: all local cores, Spark UI on, 8G of driver memory,
# and the YARN timeline service switched off.
conf = (
    gps.geopyspark_conf(appName="Landsat")
    .setMaster("local[*]")
    .set(key='spark.ui.enabled', value='true')
    .set(key="spark.driver.memory", value="8G")
    # Pass the string explicitly instead of relying on str(False) == "False".
    .set("spark.hadoop.yarn.timeline-service.enabled", "false")
)
# getOrCreate returns the already-running context when this cell is re-run in
# a live kernel; calling SparkContext(conf=conf) a second time raises instead.
sc = SparkContext.getOrCreate(conf=conf)

In [None]:
# Grab data for Nevada
state_name, county_name = "NV", "Mineral"
def get_state_shapes(state, county):
    """Download the outline of a US state and of one of its counties.

    Both outlines are read as GeoJSON from the johan/world.geo.json GitHub
    repository; the geometry of the first feature of each document is used.

    Args:
        state: State identifier as used in the repository paths (e.g. "NV").
        county: County name as used in the repository paths (e.g. "Mineral").

    Returns:
        Tuple ``(state_ll, state_wm, county_ll, county_wm)`` of shapely
        geometries: the ``_ll`` shapes are as downloaded (EPSG:4326) and the
        ``_wm`` shapes are the same geometries reprojected to EPSG:3857.

    Raises:
        urllib.error.URLError: If either document cannot be downloaded.
    """
    # Function-scope import so this cell does not depend on edits elsewhere.
    from urllib.parse import quote

    project = partial(
        pyproj.transform,
        pyproj.Proj(init='epsg:4326'),
        pyproj.Proj(init='epsg:3857'))

    # Quote the path segments so names containing spaces or other reserved
    # characters still produce a valid URL.
    base_url = "https://raw.githubusercontent.com/johan/world.geo.json/master/countries/USA"
    state_url = "{}/{}.geo.json".format(base_url, quote(state))
    county_url = "{}/{}/{}.geo.json".format(base_url, quote(state), quote(county))

    def read_geometry(url):
        # The context manager closes the HTTP response once it has been read.
        with urllib.request.urlopen(url) as response:
            document = json.loads(response.read().decode("utf-8"))
        return shape(document['features'][0]['geometry'])

    state_ll = read_geometry(state_url)
    state_wm = transform(project, state_ll)
    county_ll = read_geometry(county_url)
    county_wm = transform(project, county_ll)
    return (state_ll, state_wm, county_ll, county_wm)

(state_ll, state_wm, county_ll, county_wm) = get_state_shapes(state_name, county_name)

In [None]:
import time
def timeit(fn):
    """Call ``fn`` with no arguments, print how long it took, and return its result.

    Args:
        fn: Zero-argument callable to run once.

    Returns:
        Whatever ``fn()`` returned.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump if the clock is adjusted mid-measurement.
    t0 = time.perf_counter()
    x = fn()
    total = time.perf_counter() - t0

    print("Took {}".format(total))
    return x

In [None]:
# Name of the NLCD layer in the catalog; used by the TMS server and gps.query cells below.
nlcd_layer_name = "nlcd-zoomed-256"

# View NLCD from GeoTrellis Catalog

In [None]:

# Build a TMS server from the (catalog URI, layer name) pair, rendered with
# geopyspark's NLCD colormap.
tms_server = gps.TMS.build(("s3://datahub-catalogs-us-east-1", nlcd_layer_name), 
                           display=gps.ColorMap.nlcd_colormap())


In [None]:
# NOTE(review): `M` is not defined in the visible cells; presumably it is the
# map object provided by the geonotebook kernel. Confirm.
M.add_layer(TMSRasterData(tms_server), name="nlcd")
# NOTE(review): if these are lon/lat, (-120.32, 47.84) is not in Nevada; a later
# cell re-centers on the county centroid. Confirm this initial view is intended.
M.set_center(-120.32, 47.84, 7)

# Read State NLCD Tiles

In [None]:
# Center the map on the centroid of the county outline (county_ll, the
# un-projected geometry); the third argument is presumably the zoom level.
p = county_ll.centroid
M.set_center(p.x, p.y, 9)

## Work with County RDD

In [None]:
def get_layer():
    """Query the NLCD catalog layer at zoom 13 and return the cached result.

    The query passes ``state_wm`` as the query geometry and asks for 100
    partitions; ``.cache()`` is called on the queried layer before returning.
    """
    nlcd_tiles = gps.query(
        "s3://datahub-catalogs-us-east-1",
        nlcd_layer_name,
        layer_zoom=13,
        query_geom=state_wm,
        num_partitions=100,
    )
    return nlcd_tiles.cache()


# Run the query through the timing helper so the elapsed time is printed.
layer = timeit(get_layer)

In [None]:
# Display the minimum and maximum of the queried layer.
layer.get_min_max()

In [None]:
# Apply `+ 10` to the layer and display the min/max of the result.
(layer + 10).get_min_max()

In [None]:
# Remove every layer currently on the map. Iterate over a snapshot, because
# removing from M.layers while looping over it directly can skip entries or fail.
for map_layer in list(M.layers):
    M.remove_layer(map_layer)

In [None]:
# Repartition the layer (100 partitions) and build the pyramid that the next
# cell passes to the TMS server.
pyramid = layer.repartition(100).pyramid()


In [None]:
# Rebuild the TMS server, this time from the pyramid computed in this notebook
# instead of the catalog; same NLCD colormap.
tms_server = gps.TMS.build(pyramid, 
                           display=gps.ColorMap.nlcd_colormap())

In [None]:
# Add the current tms_server to the map as the "nlcd" layer.
M.add_layer(TMSRasterData(tms_server), name="nlcd")

In [None]:
# Overlay the state outline (state_ll, converted to a GeoJSON-style mapping) as the "poly" layer.
M.add_layer(GeoJsonData(mapping(state_ll)), name="poly")

In [None]:
# Remove the first entry of M.layers.
# NOTE(review): presumably the "nlcd" raster added before the outline; confirm.
M.remove_layer(M.layers[0])

In [None]:
# Mask the queried layer with the state outline in EPSG:3857 (state_wm), then
# build a pyramid from the masked layer and show it as the "nlcd" map layer.
masked = layer.mask(geometries=state_wm)
masked_pyramid = masked.repartition(100).pyramid()
tms_server = gps.TMS.build(masked_pyramid, 
                           display=gps.ColorMap.nlcd_colormap())
M.add_layer(TMSRasterData(tms_server), name="nlcd")

In [None]:
# Remove every layer currently on the map. Iterate over a snapshot, because
# removing from M.layers while looping over it directly can skip entries or fail.
for map_layer in list(M.layers):
    M.remove_layer(map_layer)

In [None]:
# Colormap that highlights a single class: value 82 ('Cultivated Crops' in the
# labels table later in this notebook) is matched exactly; every other value
# gets the fallback color.
# NOTE(review): the colors look like 0xRRGGBBAA (opaque green / fully
# transparent); confirm against the ColorMap.build documentation.
cultivated_land_colormap = gps.ColorMap.build(breaks={82: 0x00FF00FF},
                                              classification_strategy=gps.ClassificationStrategy.EXACT,
                                              fallback=0x00000000)    
# Serve the masked pyramid with that colormap as the "nlcd" map layer.
tms_server = gps.TMS.build(masked_pyramid, 
                           display=cultivated_land_colormap)
M.add_layer(TMSRasterData(tms_server), name="nlcd")

In [None]:
# Remove the first entry of M.layers.
M.remove_layer(M.layers[0])

In [None]:
# Convert the masked layer to an RDD and peek at its first element. The
# counting cell below reads each element's second item (x[1]) as a tile and
# uses its `.cells` array.
rdd = masked.to_numpy_rdd()
rdd.first()

In [None]:
import numpy as np
def get_counts(tile, no_data_value=-128):
    """Count how many cells of each value a tile contains, leaving out NoData.

    Args:
        tile: Object exposing a ``cells`` NumPy array (any shape).
        no_data_value: Cell value to exclude from the result. Defaults to
            -128, the value this notebook treats as NoData.

    Returns:
        dict mapping each distinct cell value (other than ``no_data_value``)
        to its number of occurrences, using native Python scalars.
    """
    # np.unique flattens its input itself, so no flatten() copy is needed.
    values, counts = np.unique(tile.cells, return_counts=True)
    # tolist() yields plain Python numbers, which keeps the merged dictionary
    # readable when displayed and independent of NumPy scalar types.
    return {
        value: count
        for value, count in zip(values.tolist(), counts.tolist())
        if value != no_data_value
    }

def merge_counts(d1, d2):
    """Combine two count dictionaries, summing the counts of keys they share.

    Neither input is modified; a new dictionary is returned that holds every
    key found in ``d1`` or ``d2``.
    """
    merged = {}
    for key in set(d1.keys()).union(set(d2.keys())):
        # sum() starts from 0 and adds d1's count before d2's, when present.
        merged[key] = sum(source[key] for source in (d1, d2) if key in source)
    return merged

# Count cell values per tile (x[1] is the tile of each RDD element), then merge
# the per-tile dictionaries into one total and display it.
counts = rdd.map(lambda x: get_counts(x[1])).reduce(merge_counts)
counts

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Human-readable names for the land cover class codes found in the tiles.
labels = { 0: 'NoData',
          11: 'Open Water',
          12: 'Perennial Ice/Snow',
          21: 'Developed, Open Space',
          22: 'Developed, Low Intensity',
          23: 'Developed, Medium Intensity',
          24: 'Developed, High Intensity',
          31: 'Barren Land (Rock/Sand/Clay)',
          41: 'Deciduous Forest',
          42: 'Evergreen Forest',
          43: 'Mixed Forest',
          52: 'Shrub/Scrub',
          71: 'Grassland/Herbaceous',
          81: 'Pasture/Hay',
          82: 'Cultivated Crops',
          90: 'Woody Wetlands',
          95: 'Emergent Herbaceous Wetlands'}

# Translate class codes to names. Codes are visited in sorted order so the
# table and chart rows are deterministic, and a code missing from `labels`
# gets a placeholder name instead of raising KeyError.
named_counts = {}
for k in sorted(counts):
    named_counts[labels.get(k, 'Unknown ({})'.format(k))] = counts[k]

# One row per land cover class; the single column holds the cell count.
df = pd.DataFrame.from_dict(named_counts, orient='index')

In [None]:
# Display the table of cell counts per named land cover class.
df

In [None]:
# DataFrame.plot creates its own figure when no axes are passed, so a bare
# plt.figure() call leaves an empty extra figure behind. Draw onto explicit
# axes instead and label them so the chart stands on its own.
fig, ax = plt.subplots()
df.plot.bar(ax=ax, legend=False)
ax.set_xlabel("Land cover class")
ax.set_ylabel("Cell count")
ax.set_title("NLCD cell counts by land cover class")
plt.show()