Skip to content

Commit

Permalink
Merge pull request #84 from mikedorfman/master
Browse files Browse the repository at this point in the history
Update spatial.py to conform to PEP-8 standards, update docstring returns
  • Loading branch information
Leah Wasser committed Oct 30, 2018
2 parents 1eb004b + 9020bba commit 809bf29
Showing 1 changed file with 71 additions and 46 deletions.
117 changes: 71 additions & 46 deletions earthpy/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import geopandas as gpd
import rasterio as rio
from rasterio.mask import mask
import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -52,6 +51,11 @@ def normalized_diff(b1, b2):
----------
b1, b2 : arrays with the same shape
Math will be calculated (b2-b1) / (b2+b1).
Returns
----------
n_diff : ndarray with the same shape as inputs
The element-wise result of (b2-b1) / (b2+b1) with all nan values masked.
"""
if not (b1.shape == b2.shape):
raise ValueError("Both arrays should be of the same dimensions")
Expand All @@ -73,12 +77,22 @@ def stack_raster_tifs(band_paths, out_path, arr_out=True):
A list with paths to the bands you wish to stack. Bands
will be stacked in the order given in this list.
out_path : string
A path with a file name for the output stacked raster
A path with a file name for the output stacked raster
tif file.
arr_out : boolean
A boolean argument to designate what is returned in the stacked
raster tif output.
Returns
----------
If arr_out keyword is True:
tuple: The first value representing the result of src.read() of the stacked array and the second value
representing the result of src.profile of the stacked array.
If arr_out keyword is False:
str : A path with a file name for the output stacked raster tif file.
TODO: Instead of returning a file path when arr_out=False, consider returning None since the out_path is already
an input given by the user. This will make the output type consistent.
"""
# Set default import to read
kwds = {'mode': 'r'}
Expand Down Expand Up @@ -128,14 +142,14 @@ def stack(sources, dest):
raise ValueError("The sources object should be of type: rasterio.DatasetReader")

for ii, ifile in enumerate(sources):
bands = sources[ii].read()
if bands.ndim != 3:
bands = bands[np.newaxis, ...]
for band in bands:
dest.write(band, ii+1)
bands = sources[ii].read()
if bands.ndim != 3:
bands = bands[np.newaxis, ...]
for band in bands:
dest.write(band, ii+1)


def crop_image(raster, geoms, all_touched = True):
def crop_image(raster, geoms, all_touched=True):
"""Crop a single file using geometry objects.
Parameters
Expand Down Expand Up @@ -167,12 +181,12 @@ def crop_image(raster, geoms, all_touched = True):
else:
clip_ext = geoms
# Mask the input image and update the metadata
out_image, out_transform = rio.mask.mask(raster, clip_ext, crop = True, all_touched = all_touched)
out_image, out_transform = rio.mask.mask(raster, clip_ext, crop=True, all_touched=all_touched)
out_meta = raster.meta.copy()
out_meta.update({"driver": "GTiff",
"height": out_image.shape[1],
"width": out_image.shape[2],
"transform": out_transform})
"height": out_image.shape[1],
"width": out_image.shape[2],
"transform": out_transform})
return (out_image, out_meta)


Expand Down Expand Up @@ -247,7 +261,7 @@ def bytescale(data, cmin=None, cmax=None, high=255, low=0):

# TODO: verify colorbar works with the latest matplotlib, and is not too wide

def colorbar(mapobj, size = "3%", pad=0.09, aspect=20):
def colorbar(mapobj, size="3%", pad=0.09, aspect=20):
"""
Adjusts the height of a colorbar to match the axis height.
----------
Expand Down Expand Up @@ -279,7 +293,7 @@ def colorbar(mapobj, size = "3%", pad=0.09, aspect=20):


# Function to plot all layers in a stack
def plot_bands(arr, title = None, cmap = "Greys_r", figsize=(12,12), cols = 3, extent = None):
def plot_bands(arr, title=None, cmap="Greys_r", figsize=(12, 12), cols=3, extent=None):
"""
Plot each layer in a raster stack converted into a numpy array for quick visualization.
Expand All @@ -290,16 +304,18 @@ def plot_bands(arr, title = None, cmap = "Greys_r", figsize=(12,12), cols = 3, e
cols: int the number of columsn you want to plot in
figsize: tuple. the figsize if you'd like to define it. default: (12, 12)
extent: an extent object for plotting
Return
Returns
----------
matplotlib plot of all layers
fig, ax or axs : figure object, axes object
The figure and axes object(s) associated with the plot.
"""
# If the array is 3 dimensional setup grid plotting
if arr.ndim > 2 and arr.shape[0] > 1:
# test if there are enough titles to create plots
if title:
if not (len(title) == arr.shape[0]):
raise ValueError("The number of plot titles should be the same as the number of raster layers in your array.")
if not (len(title) == arr.shape[0]):
raise ValueError("The number of plot titles should be the same " +
"as the number of raster layers in your array.")
# Calculate the total rows that will be required to plot each band
plot_rows = int(np.ceil(arr.shape[0] / cols))
total_layers = arr.shape[0]
Expand All @@ -318,30 +334,31 @@ def plot_bands(arr, title = None, cmap = "Greys_r", figsize=(12,12), cols = 3, e
# This loop clears out the plots for bands 8-9 which are empty
# But you have to populate them in matplotlib when you specify plot rows and cols
for ax in axs_ravel[total_layers:]:
ax.set_axis_off()
ax.set(xticks=[], yticks=[])
ax.set_axis_off()
ax.set(xticks=[], yticks=[])

plt.tight_layout()
return fig, axs
elif arr.ndim == 2 or arr.shape[0] == 1:
# If it's a 2 dimensional array with a 3rd dimension
if arr.shape[0] == 1:
arr = arr[0]
# Plot all bands
fig, ax = plt.subplots(figsize=figsize)
ax.imshow(bytescale(arr), cmap=cmap,
extent = extent)
extent=extent)
if title:
ax.set(title=title)
ax.set(xticks=[], yticks=[])


def plot_rgb(arr, rgb = (0,1,2),
ax = None,
extent = None,
title = "",
figsize = (10,10),
stretch = None,
str_clip = 2):
return fig, ax

def plot_rgb(arr, rgb=(0, 1, 2),
ax=None,
extent=None,
title="",
figsize=(10, 10),
stretch=None,
str_clip=2):
"""
Plot each layer in a raster stack converted into a numpy array for quick visualization.
Expand All @@ -358,8 +375,9 @@ def plot_rgb(arr, rgb = (0,1,2),
Returns
----------
ax : matplotlib Axes
Axes with plot of 3 band image.
fig, ax : figure object, axes object
The figure and axes object associated with the 3 band image. If the ax keyword is specified,
the figure return will be None.
"""

if len(arr.shape) != 3:
Expand Down Expand Up @@ -390,18 +408,21 @@ def plot_rgb(arr, rgb = (0,1,2),

# Then plot. Define ax if it's default to none
if ax is None:
fig, ax = plt.subplots(figsize = figsize)
ax.imshow(rgb_bands, extent = extent)
fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
ax.imshow(rgb_bands, extent=extent)
ax.set_title(title)
ax.set(xticks=[], yticks=[])
return fig, ax



def hist(arr,
title = None,
colors = ["purple"],
figsize=(12,12), cols = 2,
bins = 20):
title=None,
colors=["purple"],
figsize=(12, 12), cols=2,
bins=20):
"""
Plot histogram each layer in a raster stack converted into a numpy array for quick visualization.
Expand All @@ -413,17 +434,19 @@ def hist(arr,
cols: int the number of columsn you want to plot in
bins: the number of bins to calculate for the histogram
figsize: tuple. the figsize if you'd like to define it. default: (12, 12)
Return
Returns
----------
matplotlib plot of all layers
fig, ax or axs : figure object, axes object
The figure and axes object(s) associated with the histogram.
"""

# If the array is 3 dimensional setup grid plotting
if arr.ndim > 2:
# Test if there are enough titles to create plots
if title:
if not (len(title) == arr.shape[0]):
raise ValueError("The number of plot titles should be the same as the number of raster layers in your array.")
if not (len(title) == arr.shape[0]):
raise ValueError("The number of plot titles should be the same " +
"as the number of raster layers in your array.")
# Calculate the total rows that will be required to plot each band
plot_rows = int(np.ceil(arr.shape[0] / cols))
total_layers = arr.shape[0]
Expand All @@ -441,8 +464,9 @@ def hist(arr,
ax.set_title(title[i])
# Clear additional axis elements
for ax in axs_ravel[total_layers:]:
ax.set_axis_off()
#ax.set(xticks=[], yticks=[])
ax.set_axis_off()
#ax.set(xticks=[], yticks=[])
return fig, axs
elif arr.ndim == 2:
# Plot all bands
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -452,7 +476,7 @@ def hist(arr,
color=colors[0])
if title:
ax.set(title=title[0])

return fig, ax

def hillshade(arr, azimuth=30, angle_altitude=30):
"""
Expand All @@ -476,7 +500,8 @@ def hillshade(arr, azimuth=30, angle_altitude=30):
azimuthrad = azimuth*np.pi/180.
altituderad = angle_altitude*np.pi/180.

shaded = np.sin(altituderad)*np.sin(slope) + np.cos(altituderad)*np.cos(slope)*np.cos((azimuthrad - np.pi/2.) - aspect)
shaded = (np.sin(altituderad)*np.sin(slope) +
np.cos(altituderad) * np.cos(slope) * np.cos((azimuthrad - np.pi / 2.) - aspect))

return 255*(shaded + 1)/2

Expand Down

0 comments on commit 809bf29

Please sign in to comment.