Test GPU apply #74
Labels: enhancement (New feature or request)
See the `series` class here for more details.
Custom classes must have the following structure:

```python
import geowombat as gw


class Custom(gw.TimeModule):
    def __init__(self):
        super(Custom, self).__init__()

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """
        # The returned array should have dimensions [output band count x rows x columns]
        res = ...  # placeholder: replace with your calculation
        return res
```

Override the output data type, band count, and compression

```python
import geowombat as gw


class Custom(gw.TimeModule):
    def __init__(self):
        super(Custom, self).__init__()
        self.dtype = 'uint16'
        self.count = 2
        self.compress = 'lzw'

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """
        # If the array is shaped [20 x 2 x 100 x 100], the mean over time is
        # shaped [2 x 100 x 100], which matches self.count = 2
        res = array.mean(axis=0)
        return res
```

Using the band dictionary

```python
import geowombat as gw
import jax.numpy as jnp


class Custom(gw.TimeModule):
    def __init__(self):
        super(Custom, self).__init__()

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """
        # self.band_dict maps each name in band_names to its band position
        red = self.band_dict['red']
        nir = self.band_dict['nir']
        s1 = (slice(0, None), slice(red, red + 1), slice(0, None), slice(0, None))
        s2 = (slice(0, None), slice(nir, nir + 1), slice(0, None), slice(0, None))
        # NDVI = (NIR - red) / (NIR + red)
        ndvi = (array[s2] - array[s1]) / (array[s2] + array[s1])
        # Mean NDVI over time
        res = jnp.nanmean(ndvi, axis=0).squeeze()
        return res


with gw.series([...], band_names=['red', 'nir']) as src:
    src.apply(Custom(), 'outfile.tif', bands=[3, 4])
```
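A minimal sketch of the band-dictionary NDVI calculation as a standalone function, so the math can be checked without geowombat or a GPU. The helpers `band_slice` and `mean_ndvi` are hypothetical, and the band dictionary is assumed to map each entry of `band_names` to its position (`{'red': 0, 'nir': 1}`). It uses jax.numpy when installed and otherwise falls back to NumPy, which supports the same calls here.

```python
import numpy as np

# Assumption: use jax.numpy when available; NumPy supports the same calls here
try:
    import jax.numpy as jnp
except ImportError:
    jnp = np


def band_slice(band_dict, name):
    """Hypothetical helper: index tuple that selects one band, keeping a
    length-1 band axis, from a [time x bands x rows x columns] array."""
    i = band_dict[name]
    return (slice(None), slice(i, i + 1), slice(None), slice(None))


def mean_ndvi(array, band_dict):
    """Mean NDVI over time, returned as [rows x columns]."""
    red = array[band_slice(band_dict, 'red')]
    nir = array[band_slice(band_dict, 'nir')]
    ndvi = (nir - red) / (nir + red)  # 0 / 0 gives NaN
    # nanmean skips NaN dates; squeeze(axis=0) drops only the band axis
    return jnp.nanmean(ndvi, axis=0).squeeze(axis=0)


# Assumed mapping: band positions follow the order of band_names
band_names = ['red', 'nir']
band_dict = {name: i for i, name in enumerate(band_names)}

# 3 dates x 2 bands x 4 rows x 5 columns of constant reflectance
data = np.ones((3, 2, 4, 5), dtype='float32')
data[:, band_dict['red']] = 0.2
data[:, band_dict['nir']] = 0.6
result = mean_ndvi(data, band_dict)
print(result.shape)  # (4, 5)
```

Using `squeeze(axis=0)` instead of a bare `squeeze()` keeps the row and column axes even when a chunk happens to be one row or one column tall.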
I created a page for time series on GPUs here.
Very excited to try this out!
On Sun, Jul 4, 2021 at 1:59 AM Jordan wrote:
> I created a page for time series on GPUs here
> <https://github.com/jgrss/geowombat/blob/jax/doc/source/gpu.rst>.
I will close this since it is now implemented. However, feel free to reopen.
@rdenham if you have some spare time, test the new `series` class from the `jax` branch. I removed the xarray and dask dependencies, so opening large images should be faster. Additionally, computation runs on the GPU. To test it:

1. Install `geowombat` from the `jax` branch.
2. Install `jax` with GPU support (the command below is for CUDA 11.1).

Simple example to compute the mean over time:

Create a custom class:
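The mean-over-time computation can be sketched as a plain function, outside of `gw.TimeModule`, so it runs without geowombat or a GPU. `temporal_mean` and its band-count check are hypothetical, chosen to mirror the `count` and `dtype` attributes a custom module can set; jax.numpy is used when installed, otherwise NumPy.

```python
import numpy as np

# Assumption: use jax.numpy when available; NumPy supports the same calls here
try:
    import jax.numpy as jnp
except ImportError:
    jnp = np


def temporal_mean(array, count, dtype='float32'):
    """Hypothetical helper: mean over the time axis of a
    [time x bands x rows x columns] array.

    Returns [bands x rows x columns] cast to ``dtype`` and raises if the
    band count differs from ``count`` (mirroring a module's ``self.count``).
    """
    res = jnp.mean(array, axis=0)
    if res.shape[0] != count:
        raise ValueError(f'expected {count} output bands, got {res.shape[0]}')
    return res.astype(dtype)


# 4 dates whose values are 0, 1, 2, 3, each with 2 bands x 3 rows x 3 columns
data = np.stack([np.full((2, 3, 3), t, dtype='float32') for t in range(4)])
print(temporal_mean(data, count=2).shape)  # (2, 3, 3)
```

Casting a float mean to an integer type such as `uint16` truncates the fractional part (a mean of 1.5 becomes 1), so scale the values before casting if that precision matters.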