Test GPU apply #74

Closed
jgrss opened this issue Jul 3, 2021 · 5 comments
Labels
enhancement New feature or request

Comments

@jgrss
Owner

jgrss commented Jul 3, 2021

@rdenham, if you have some spare time, please test the new `series` class from the `jax` branch. I removed the xarray and dask dependencies, so opening large images should be faster. Computation also runs on the GPU.

To test it:

1. Install geowombat from the jax branch.
2. Install jax with GPU support (the command below is for CUDA 11.1):

```bash
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

Simple example to compute the mean over time:

```python
import geowombat as gw


# Create a custom class
class TemporalMean(gw.TimeModule):

    def __init__(self):
        super(TemporalMean, self).__init__()

    def calculate(self, array):
        """Returns the mean along the time axis"""
        return array.mean(axis=0).squeeze()


# file_list is a list of image file paths
with gw.series(file_list, nodata=0, num_threads=4, window_size=(1024, 1024)) as src:
    src.apply(TemporalMean(), outfile='mean.tif', bands=2, gain=0.0001, num_workers=4)
```
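To see what `calculate` receives and returns, the sketch below runs the same reduction on a synthetic block with plain NumPy (`jax.numpy` exposes the same `mean` and `squeeze` calls). The block shape, the random data, and the helper name `temporal_mean` are illustrative assumptions, not part of geowombat.

```python
import numpy as np


def temporal_mean(array: np.ndarray) -> np.ndarray:
    """Mean over the time axis of a [time x bands x rows x columns] block."""
    # Reduce axis 0 (time), then drop the length-1 band axis
    return array.mean(axis=0).squeeze()


# Hypothetical window: 5 dates, 1 band, 4 x 3 pixels
rng = np.random.default_rng(0)
block = rng.random((5, 1, 4, 3))

result = temporal_mean(block)
print(result.shape)  # (4, 3)
```

One design note: `squeeze` removes every length-1 axis, so a window that is a single row or column wide would lose that axis too; indexing the band explicitly (for example `array[:, 0].mean(axis=0)`) avoids that.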
@jgrss jgrss added the enhancement New feature or request label Jul 3, 2021
@jgrss
Owner Author

jgrss commented Jul 3, 2021

See the series class here for more details.

@jgrss
Owner Author

jgrss commented Jul 3, 2021

Custom classes must have the following structure:

```python
import geowombat as gw


class Custom(gw.TimeModule):

    def __init__(self):
        super(Custom, self).__init__()

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """
        # The returned array should have dimensions [output band count x rows x columns].
        # Replace this placeholder computation with your own.
        res = array.mean(axis=0)
        return res
```
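The shape contract can be checked locally before running on real imagery. The sketch below uses a hypothetical stand-in base class (`FakeTimeModule`) and a hypothetical `check_output` helper, neither of which is geowombat API, to confirm that a `calculate` implementation turns a [time x bands x rows x columns] input into an [output band count x rows x columns] result.

```python
import numpy as np


class FakeTimeModule:
    """Hypothetical stand-in for gw.TimeModule, for local testing only."""

    def __init__(self):
        # Assumed default: one output band
        self.count = 1


class Custom(FakeTimeModule):

    def __init__(self):
        super().__init__()
        self.count = 2

    def calculate(self, array):
        # Placeholder computation: the mean over time keeps the band axis
        return array.mean(axis=0)


def check_output(module, array):
    """Run module.calculate and verify a [count x rows x columns] result."""
    rows, cols = array.shape[-2:]
    res = module.calculate(array)
    if res.ndim == 2:
        # A squeezed single-band result is [rows x columns]; restore the band axis
        res = res[np.newaxis]
    expected = (module.count, rows, cols)
    if res.shape != expected:
        raise ValueError(f'expected shape {expected}, got {res.shape}')
    return res


out = check_output(Custom(), np.ones((10, 2, 8, 8)))
print(out.shape)  # (2, 8, 8)
```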

To override the output data type, band count, and compression:

```python
import geowombat as gw


class Custom(gw.TimeModule):

    def __init__(self):

        super(Custom, self).__init__()

        self.dtype = 'uint16'
        self.count = 2
        self.compress = 'lzw'

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """

        # If the array is shaped [20 x 2 x 100 x 100], the result is
        # [2 x 100 x 100], which matches self.count = 2
        res = array.mean(axis=0).squeeze()

        return res
```
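Setting `self.count = 2` declares two output bands and `self.dtype = 'uint16'` an unsigned 16-bit integer type, so whatever `calculate` returns should have two bands with values that fit that type. A minimal NumPy sketch, assuming you want the temporal mean and standard deviation of a single-band series and that rounding and clipping to the uint16 range is acceptable (whether geowombat casts for you is not stated in this thread); `mean_and_std` is a hypothetical helper:

```python
import numpy as np


def mean_and_std(array: np.ndarray) -> np.ndarray:
    """Return a [2 x rows x columns] uint16 stack of the temporal mean and std.

    Expects a single-band block shaped [time x 1 x rows x columns].
    """
    series = array[:, 0]  # [time x rows x columns]
    stats = np.stack([series.mean(axis=0), series.std(axis=0)])
    # uint16 holds integers in [0, 65535]: round, clip, then cast
    info = np.iinfo(np.uint16)
    return np.clip(np.rint(stats), info.min, info.max).astype(np.uint16)


# Hypothetical block: 3 dates, 1 band, 2 x 2 pixels with a constant value per date
block = np.stack([np.full((1, 2, 2), v) for v in (10.0, 20.0, 30.0)])
out = mean_and_std(block)
print(out.shape, out.dtype)  # (2, 2, 2) uint16
```

Here the mean is 20 and the population standard deviation is about 8.16, which rounds to 8; fractional outputs would otherwise be truncated by a plain cast.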

Using the band dictionary

```python
import geowombat as gw
import jax.numpy as jnp


class Custom(gw.TimeModule):

    def __init__(self):
        super(Custom, self).__init__()

    def calculate(self, array):
        """
        Args:
            array: JAX DeviceArray shaped [time x bands x rows x columns]
        """

        # self.band_dict maps each name in band_names to its position on the band axis;
        # slice(b, b + 1) keeps a length-1 band axis
        red = (slice(0, None), slice(self.band_dict['red'], self.band_dict['red'] + 1), slice(0, None), slice(0, None))
        nir = (slice(0, None), slice(self.band_dict['nir'], self.band_dict['nir'] + 1), slice(0, None), slice(0, None))

        # NDVI = (NIR - red) / (NIR + red)
        ndvi = (array[nir] - array[red]) / (array[nir] + array[red])

        # Mean NDVI over time
        res = jnp.nanmean(ndvi, axis=0).squeeze()

        return res


with gw.series([...], band_names=['red', 'nir']) as src:
    src.apply(Custom(), 'outfile.tif', bands=[3, 4])
```
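Slicing with `slice(b, b + 1)` rather than an integer index keeps the band axis, so the NDVI array stays four-dimensional until the final reduction. The NumPy sketch below reproduces that logic on synthetic reflectances, assuming a hypothetical `band_dict` of `{'red': 0, 'nir': 1}` (the mapping geowombat builds from `band_names` may differ) and using the standard NDVI definition (NIR - red) / (NIR + red). `mean_ndvi` is a hypothetical helper.

```python
import numpy as np

# Hypothetical mapping from band name to position along axis 1
band_dict = {'red': 0, 'nir': 1}


def mean_ndvi(array: np.ndarray, band_index: dict) -> np.ndarray:
    """Mean NDVI over time for a [time x bands x rows x columns] block."""
    r, n = band_index['red'], band_index['nir']
    # Slicing b:b + 1 (like slice(b, b + 1)) keeps a length-1 band axis
    red = array[:, r:r + 1]
    nir = array[:, n:n + 1]
    # 0 / 0 (both bands zero, e.g. nodata) yields NaN; silence NumPy's warning
    with np.errstate(divide='ignore', invalid='ignore'):
        ndvi = (nir - red) / (nir + red)
    # nanmean skips NaN dates; squeeze drops the band axis and any other length-1 axis
    return np.nanmean(ndvi, axis=0).squeeze()


# Hypothetical block: 2 dates, 2 bands (red, nir), 1 x 2 pixels
block = np.array([
    [[[0.1, 0.0]], [[0.5, 0.0]]],  # date 0: red, nir
    [[[0.1, 0.2]], [[0.3, 0.6]]],  # date 1: red, nir
])
result = mean_ndvi(block, band_dict)
print(result)  # about [0.583, 0.5]
```

The second pixel is 0 / 0 on the first date, so its mean comes from the second date alone. The slicing, arithmetic, and `nanmean` call carry over to `jax.numpy`; the `np.errstate` guard is NumPy-specific.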

@jgrss
Owner Author

jgrss commented Jul 4, 2021

I created a page for time series on GPUs here.

@mmann1123
Collaborator

mmann1123 commented Jul 9, 2021 via email

@jgrss
Owner Author

jgrss commented Sep 22, 2022

I will close this since it is now implemented. However, feel free to reopen.

@jgrss jgrss closed this as completed Sep 22, 2022