In [None]:
from pathlib import Path

import geopandas
import dask_geopandas
import matplotlib.pyplot as plt

In [None]:
# GADM 4.0 boundaries for the USA. A local copy is used when present (avoids a large
# download); otherwise the file is read directly from the GADM server.
GADM_URL = "https://geodata.ucdavis.edu/gadm/gadm4.0/gpkg/gadm40_USA.gpkg"
GADM_LOCAL = Path(
    "~/src/project/cmr-bigstac-prototype/bigstac/scripts_explore/gadm40_USA.gpkg"
).expanduser()

path = GADM_LOCAL if GADM_LOCAL.exists() else GADM_URL

# Layer ADM_1 holds the first-level administrative divisions.
usa = geopandas.read_file(path, layer="ADM_1")
usa.head()

In [None]:
# Keep just the name and geometry columns, exposing NAME_1 under the clearer name "State".
usa = (
    usa.loc[:, ["NAME_1", "geometry"]]
    .rename(columns={"NAME_1": "State"})
)
usa.info()

In [None]:
# Quick visual check of every ADM_1 geometry that was loaded.
usa.plot().set_title("GADM ADM_1 units (USA)")
plt.show()

In [None]:
# Longitude/latitude window intended to keep the contiguous states.
# NOTE(review): confirm that Alaska and Hawaii fall outside this window.
# GeoDataFrame.cx keeps every geometry that intersects the window; it does not clip.
LON_MIN, LON_MAX = -150, -50
LAT_MIN, LAT_MAX = 20, 50

us_cont = usa.cx[LON_MIN:LON_MAX, LAT_MIN:LAT_MAX]
us_cont.info()

In [None]:
# Unfilled red outlines of the filtered states, to eyeball the bbox selection.
us_cont.plot(
    edgecolor="red",
    facecolor="none",
    linewidth=0.5,
)

In [None]:
# Wrap the filtered states in a dask GeoDataFrame split into 4 partitions
# (no spatial shuffle applied yet).
d_gdf = dask_geopandas.from_geopandas(
    us_cont,
    npartitions=4,
)
d_gdf

In [None]:
# Compute the convex hull of each partition; the hulls are then read back from
# d_gdf.spatial_partitions.
d_gdf.calculate_spatial_partitions()

d_gdf.spatial_partitions

In [None]:
# Overlay the per-partition hulls (semi-transparent) on top of the state geometries.
fig, ax = plt.subplots(figsize=(12, 6))

us_cont.plot(ax=ax)
d_gdf.spatial_partitions.plot(ax=ax, alpha=0.5, cmap="tab20")

ax.set_axis_off()
plt.show()

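Now try the spatial sorting methods of `spatial_shuffle` (`hilbert`, `morton`, and `geohash`) and compare the partitions each one produces.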

In [None]:
# Re-partition d_gdf along each space-filling ordering, keeping the default npartitions.
hilbert, morton, geohash = (
    d_gdf.spatial_shuffle(by=method)
    for method in ("hilbert", "morton", "geohash")
)

In [None]:
# Side-by-side comparison of the partition hulls produced by each sorting method.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(25, 12))

panels = [("Hilbert", hilbert), ("Morton", morton), ("Geohash", geohash)]
for ax, (title, shuffled) in zip(axes, panels):
    us_cont.plot(ax=ax)
    shuffled.spatial_partitions.plot(ax=ax, cmap="tab20", alpha=0.5)
    ax.set_axis_off()
    ax.set_title(title, size=16)

plt.show()

In [None]:
# Same shuffles as before, but explicitly requesting 20 output partitions.
hilbert20 = d_gdf.spatial_shuffle(npartitions=20, by="hilbert")
geohash20 = d_gdf.spatial_shuffle(npartitions=20, by="geohash")

In [None]:
# 2x3 grid comparing the unshuffled partitions, Hilbert shuffles, and individual
# geohash partitions. Only five panels are used; axes[1, 2] is left blank.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(25, 20))
ax1, ax2, ax3, ax4, ax5 = axes[0, 0], axes[0, 1], axes[0, 2], axes[1, 0], axes[1, 1]

# `axes` is 2-D, so iterate the used panels explicitly rather than `for ax in axes`
# (which would yield whole rows of Axes).
for ax in (ax1, ax2, ax3, ax4, ax5):
    us_cont.plot(ax=ax)

d_gdf.spatial_partitions.plot(ax=ax1, cmap="tab20", alpha=0.5)
hilbert.spatial_partitions.plot(ax=ax2, cmap="tab20", alpha=0.5)
hilbert20.spatial_partitions.plot(ax=ax3, cmap="tab20", alpha=0.5)
geohash20.spatial_partitions.plot(ax=ax4, cmap="tab20", alpha=0.5)

# Highlight single geohash partitions, selected by their label in the
# spatial_partitions index (presumed to be the partition number — TODO confirm).
geohash_hulls = geohash20.spatial_partitions
geohash_hulls[geohash_hulls.index == 1].plot(ax=ax4, alpha=0.5, color="red")
geohash_hulls[geohash_hulls.index == 3].plot(ax=ax5, alpha=0.5, color="green")

for axi in axes.flat:
    axi.set_axis_off()

ax1.set_title(f"No spatial shuffle, with {d_gdf.npartitions} partitions", size=16)
ax2.set_title(f"Hilbert shuffle, default npartitions ({hilbert.npartitions})", size=16)
ax3.set_title(f"Hilbert shuffle using {hilbert20.npartitions} partitions", size=16)
ax4.set_title(f"Geohash, {geohash20.npartitions} partitions (partition 1 in red)", size=16)
ax5.set_title(f"Geohash, {geohash20.npartitions} partitions: partition 3", size=16)

plt.show()

In [None]:
# Hull geometry of every geohash partition (one row per partition).
geohash20.spatial_partitions

In [None]:
# Text form of the hull stored under label 0 of geohash20.spatial_partitions.
print(geohash20.spatial_partitions.loc[0])


In [None]:
# Print the states in each geohash partition (one comma-terminated line per
# partition), then show how many distinct states appear overall.
# d_gdf / geohash20 from the cells above are reused rather than rebuilt: rebuilding
# d_gdf here would drop the spatial partitions calculated on it earlier.

# Persist once so the shuffle is not re-executed for every get_partition(...).compute().
geohash20_persisted = geohash20.persist()

counts = {}
for i in range(geohash20_persisted.npartitions):
    partition_df = geohash20_persisted.get_partition(i).compute()
    states = partition_df["State"].tolist()
    for state in states:
        counts[state] = counts.get(state, 0) + 1
    print("".join(f"{state}," for state in states))

# Shuffling should only move rows between partitions, never drop or duplicate them.
assert sum(counts.values()) == len(us_cont), "row count changed during spatial_shuffle"

len(counts)