In [1]:
%matplotlib inline

import datacube
from datacube.storage import masking

# Open the Datacube handle; `dc` is the object later used to search for
# datasets (`dc.find_datasets`) and to load them (`dc.load`).
dc = datacube.Datacube(app='test_animation')



In [2]:
# Configuration: near Cape Town, South Africa.
# NOTE: when the notebook is run top to bottom, the following configuration
# cell reassigns every name defined here, so this cell only takes effect if
# that cell is skipped.

area_name = 'south_africa_monthly'
query = {
    'lat': (-33.72, -33.62),
    'lon': (19.48, 19.60),
    'output_crs': 'EPSG:32734',
    'resolution': (-10, 10),
}
rgb_bands = ['red', 'green', 'blue']
band_index = 'ndvi'
clear_frac = 0.9
band_index_cmap = 'RdYlGn'
band_index_min, band_index_max = 0.1, 0.7
sensors = ['s2']
mask_cloud = False
percentile_stretch = (0.01, 0.98)
resample = '1M'
rolling = False
interval = 1000
# The loading cell reads `best_per_month`; without it this configuration
# raises NameError when used on its own.
# NOTE(review): False is assumed because `clear_frac` is only applied by
# load_and_combine when best_per_month is False -- confirm the intended value.
best_per_month = False

In [3]:
# Configuration: Mali wetland.
area_name = 'mali_wetland'

# Centre of the area of interest. An earlier centre (15.323201, -3.798044)
# was assigned and immediately overwritten; it is kept here only as a note.
lat, lon = 15.6, -3.8
buffer_y = 0.2
buffer_x = 0.25

query = {
    'lat': (lat - buffer_y, lat + buffer_y),
    'lon': (lon - buffer_x, lon + buffer_x),
    'output_crs': 'EPSG:32630',
    'resolution': (-10, 10),
    'cloud_cover': (0, 5),
}
rgb_bands = ['swir_1', 'nir', 'green']
band_index = None  # 'ndvi'
clear_frac = 0.95
band_index_cmap = 'RdYlGn'
band_index_min, band_index_max = 0.1, 0.7
sensors = ['s2']
mask_cloud = False
percentile_stretch = (0.02, 0.98)
resample = None
rolling = False
interval = 800
best_per_month = True

In [4]:
import xarray as xr
import numpy as np
from skimage.morphology import disk, binary_dilation

# Substrings identifying datasets to exclude: a dataset is dropped when its
# first URI contains any entry of this list.
bad = []

# Extra measurements each supported band index needs, in the order they are
# appended to the requested band list.
_INDEX_BANDS = {
    'mndwi': ['green', 'swir_1'],
    'ndbi': ['nir', 'swir_1'],
    'ndvi': ['nir', 'red'],
}

# Sentinel-2 SCL values counted as "clear" by each loading path.
# NOTE(review): the month-by-month path also accepts class 10 while the
# single-load path does not. Both lists are kept exactly as found -- confirm
# which one is intended and unify.
_S2_CLEAR_SCL_MONTHLY = [2, 4, 5, 6, 7, 10, 11]
_S2_CLEAR_SCL_SINGLE = [2, 4, 5, 6, 7, 11]


def _drop_bad_datasets(dss):
    """Return `dss` without datasets whose first URI contains an entry of `bad`."""
    return [dataset for dataset in dss
            if not any(bad_id in dataset.uris[0] for bad_id in bad)]


def _clear_pixel_mask(ds, product, dss, load_kwargs, dilation, s2_clear_classes):
    """Build the boolean clear-pixel mask for the loaded data `ds`."""
    if product == 's2_l2a':
        mask = ds.SCL.isin(s2_clear_classes)
    else:
        pq = dc.load(product=product, group_by='solar_day', datasets=dss,
                     measurements=['pixel_qa'], **load_kwargs)
        mask = masking.make_mask(pq.pixel_qa, cloud='no_cloud',
                                 cloud_shadow='no_cloud_shadow', nodata=False)

    # Very bright blue pixels are also treated as not clear.
    if 'blue' in ds:
        mask = mask & (ds.blue < 1000)

    if dilation > 0:
        # Grow the not-clear area by `dilation` pixels (the radius used to be
        # hard-coded to 10 regardless of the argument).
        # NOTE(review): newer scikit-image releases name this keyword
        # `footprint` instead of `selem` -- confirm against the installed version.
        mask = ~(~mask).groupby('time').apply(binary_dilation, selem=disk(dilation))
    return mask


def _select_and_clean(ds, mask, good, mask_cloud):
    """Keep the time steps flagged in `good`, optionally mask them, attach `mask`."""
    ds_clean = ds.isel(time=good)
    if mask_cloud:
        ds_clean = ds_clean.where(mask.isel(time=good))
    # Non-positive values are replaced by NaN.
    ds_clean = ds_clean.where(ds_clean > 0)
    ds_clean['mask'] = mask.isel(time=good)
    return ds_clean


def _add_band_index(ds_clean, band_index):
    """Add the normalised-difference variable `band_index` to `ds_clean` in place."""
    index_name = band_index.lower()
    if index_name == 'mndwi':
        ds_clean[band_index] = (ds_clean.green - ds_clean.swir_1) / (ds_clean.green + ds_clean.swir_1)
    elif index_name == 'ndbi':
        ds_clean[band_index] = (ds_clean.swir_1 - ds_clean.nir) / (ds_clean.swir_1 + ds_clean.nir)
    elif index_name == 'ndvi':
        ds_clean[band_index] = (ds_clean.nir - ds_clean.red) / (ds_clean.nir + ds_clean.red)


def load_and_combine(query, base_bands =['red','green','blue'], band_index = None, mask_cloud=True, best_per_month = True,
                     dilation=0, clear_frac=0.8, sensors=[]):
    """Load mostly-clear observations for each sensor and merge them along time.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to `dc.find_datasets` and `dc.load`.
        On the month-by-month path its 'time' entry is replaced per month.
    base_bands : sequence of str
        Measurements to load; it is copied, never modified.
    band_index : str or None
        One of 'mndwi', 'ndbi', 'ndvi' (case-insensitive). The bands it needs
        are added to the request and the index is stored under this name.
        Any other value adds nothing.
    mask_cloud : bool
        If True, pixels outside the clear mask are set to NaN.
    best_per_month : bool
        If True, keep for every month only the observation(s) with the
        highest clear fraction; `clear_frac` is then not applied.
    dilation : int
        If > 0, radius (pixels) by which the not-clear area is grown.
    clear_frac : float
        Minimum clear fraction an observation needs when `best_per_month`
        is False.
    sensors : sequence
        's2' selects product 's2_l2a'; any other value is formatted into
        'ls%s_usgs_sr_scene'.

    Returns
    -------
    The dataset of the single sensor that produced data, the time-sorted
    concatenation when several did, or None when none did. Each dataset
    carries a boolean `mask` variable.

    Data are loaded month by month (from 2017-01 up to the current time) when
    `best_per_month` is set, or when more than 100 datasets match and `query`
    has no 'time' entry; otherwise everything is loaded in one call.
    """
    bands = list(base_bands)
    if band_index is not None:
        for extra in _INDEX_BANDS.get(band_index.lower(), ()):
            if extra not in bands:
                bands.append(extra)

    datasets = []
    for sensor in sensors:
        if sensor == 's2':
            product = 's2_l2a'
            # SCL is only requested for Sentinel-2; a per-sensor list keeps it
            # out of the request for any Landsat sensor processed afterwards.
            sensor_bands = bands if 'SCL' in bands else bands + ['SCL']
        else:
            product = 'ls%s_usgs_sr_scene' % sensor
            sensor_bands = bands

        dss = _drop_bad_datasets(
            dc.find_datasets(product=product, measurements=sensor_bands, **query))
        print("Found %d good datasets for sensor %s."%(len(dss), sensor))
        if len(dss) == 0:
            continue

        if (len(dss) > 100 and 'time' not in query) or best_per_month:
            ds_set = []
            time_slice = np.datetime64('2016-12')
            while time_slice < np.datetime64('now'):
                # Incremented before use, so the first month handled is 2017-01.
                time_slice += np.timedelta64(1, 'M')
                sub_query = query.copy()
                sub_query['time'] = str(time_slice)
                month_dss = _drop_bad_datasets(
                    dc.find_datasets(product=product, measurements=sensor_bands, **sub_query))
                print(f"Found {len(month_dss)} good datasets for sensor {sensor} for {time_slice}")
                if len(month_dss) == 0:
                    continue
                ds = dc.load(product=product, group_by='solar_day', datasets=month_dss,
                             measurements=sensor_bands, **sub_query)
                mask = _clear_pixel_mask(ds, product, month_dss, sub_query,
                                         dilation, _S2_CLEAR_SCL_MONTHLY)
                good_frac = mask.mean(['x', 'y'])
                if best_per_month:
                    good = good_frac == good_frac.max()
                else:
                    # Was `lear_frac`, which raised NameError on this path.
                    good = good_frac >= clear_frac
                if good.sum() == 0:
                    continue
                ds_clean = _select_and_clean(ds, mask, good, mask_cloud)
                ds_set.append(ds_clean.sortby('time'))
                print("Found %d clear observations for sensor %s."%(len(ds_clean.time), sensor))

            # xr.concat raises on an empty list; skip sensors with no usable month.
            if not ds_set:
                continue
            ds_clean = xr.concat(ds_set, dim='time')

        else:
            ds = dc.load(product=product, group_by='solar_day', datasets=dss,
                         measurements=sensor_bands, **query)
            mask = _clear_pixel_mask(ds, product, dss, query,
                                     dilation, _S2_CLEAR_SCL_SINGLE)
            good = mask.mean(['x', 'y']) >= clear_frac
            if good.sum() == 0:
                continue
            ds_clean = _select_and_clean(ds, mask, good, mask_cloud)
            print("Found %d clear observations for sensor %s."%(len(ds_clean.time), sensor))

        if band_index is not None:
            _add_band_index(ds_clean, band_index)
        datasets.append(ds_clean)

    if not datasets:
        return None
    if len(datasets) == 1:
        return datasets[0]
    return xr.concat(datasets, dim='time').sortby('time')
    

In [5]:
# Load the imagery described by the active configuration.
ds = load_and_combine(
    query,
    base_bands=rgb_bands,
    band_index=band_index,
    mask_cloud=mask_cloud,
    best_per_month=best_per_month,
    clear_frac=clear_frac,
    sensors=sensors,
)

Found 197 good datasets for sensor s2.
Found 4 good datasets for sensor s2 for 2017-01
Found 1 clear observations for sensor s2.
Found 2 good datasets for sensor s2 for 2017-02
Found 1 clear observations for sensor s2.
Found 2 good datasets for sensor s2 for 2017-03
Found 1 clear observations for sensor s2.
Found 4 good datasets for sensor s2 for 2017-04
Found 1 clear observations for sensor s2.
Found 2 good datasets for sensor s2 for 2017-05
Found 1 clear observations for sensor s2.
Found 3 good datasets for sensor s2 for 2017-06
Found 1 clear observations for sensor s2.
Found 3 good datasets for sensor s2 for 2017-07
Found 1 clear observations for sensor s2.
Found 4 good datasets for sensor s2 for 2017-08
Found 1 clear observations for sensor s2.
Found 7 good datasets for sensor s2 for 2017-09
Found 1 clear observations for sensor s2.
Found 5 good datasets for sensor s2 for 2017-10
Found 1 clear observations for sensor s2.
Found 7 good datasets for sensor s2 for 2017-11
Found 1 clear

In [6]:
# Timestamps of the observations that were kept by load_and_combine.
ds.time.values

array(['2017-01-03T10:53:18.000000000', '2017-02-02T10:45:36.000000000',
       '2017-03-04T10:55:24.000000000', '2017-04-23T10:54:45.000000000',
       '2017-05-03T10:50:41.000000000', '2017-06-02T10:54:35.000000000',
       '2017-07-27T10:46:37.000000000', '2017-08-01T10:45:59.000000000',
       '2017-09-20T10:46:22.000000000', '2017-10-10T10:49:47.000000000',
       '2017-11-04T10:44:39.000000000', '2017-12-29T10:51:29.000000000',
       '2018-01-13T10:49:17.000000000', '2018-02-17T10:54:12.000000000',
       '2018-03-09T10:51:36.000000000', '2018-04-13T10:44:43.000000000',
       '2018-05-08T10:50:35.000000000', '2018-06-22T10:50:27.000000000',
       '2018-07-07T10:56:30.000000000', '2018-08-16T10:44:19.000000000',
       '2018-09-20T10:47:12.000000000', '2018-10-10T10:50:53.000000000',
       '2018-11-04T10:55:45.000000000', '2018-12-04T10:56:37.000000000',
       '2019-01-28T10:56:46.000000000', '2019-02-12T10:56:42.000000000',
       '2019-03-09T10:57:09.000000000', '2019-04-18

In [7]:
#for a, b in zip(ds.time.values,ds.mask.groupby('time').mean().values):print(a,b)

In [8]:
#if area_name == 'burundi':
#    good = ds[band_index].groupby('time').max().values >0.2
#    ds = ds.isel(time=good)

In [9]:
# Preview: last observation of each month, drawn as an RGB facet grid.
# Skipped when load_and_combine already kept a single best scene per month.
if not best_per_month:
    (
        ds.resample(time='1M')
        .last()
        .dropna('time', how='all')[rgb_bands]
        .to_array()
        .plot.imshow(col='time', col_wrap=4, robust=True)
    );

In [10]:
# Build the series to plot/animate: the loaded data as-is, or its median per
# `resample` period, optionally smoothed by a centred 3-step rolling median.
combined_ds = ds
if resample:
    combined_ds = combined_ds.resample(time=resample).median()
    if rolling:
        combined_ds = combined_ds.rolling(time=3, center=True, min_periods=1).median()
    # Drop periods that contain no data at all.
    combined_ds = combined_ds.dropna('time', how='all')

In [None]:
# Facet grid of the RGB composite for every time step.
(
    combined_ds[rgb_bands]
    .to_array()
    .plot.imshow(col='time', col_wrap=4, robust=True)
);

In [None]:
# Facet grid of the band index for every time step. The colour limits come
# from the configuration (band_index_min / band_index_max) so this preview
# uses the same scale as the animation instead of separate hard-coded values.
if band_index:
    combined_ds[band_index].plot.imshow(
        col='time',
        col_wrap=4,
        vmin=band_index_min,
        vmax=band_index_max,
        cmap=band_index_cmap,
        add_colorbar=False,
    );

In [None]:
# For the 'st_louis' area only: discard the first 28 time steps.
if area_name == 'st_louis':
    combined_ds = combined_ds.isel(time=slice(28, None))

In [None]:
#import sys
#sys.path.append("../Scripts")
#from deafrica_plotting import animated_timeseries, animated_doubletimeseries

from DEAPlotting import animated_timeseries, animated_doubletimeseries


# numpy datetime unit the frame titles are truncated to
# (also read by the single-panel animation cell below).
date_type = 'M8[Y]' if resample else 'M8[M]'

# Side-by-side animation: colour composite (panel 1) and band index (panel 2).
# Only produced when a band index is configured.
if band_index:
    animated_doubletimeseries(
        combined_ds[rgb_bands],
        combined_ds[[band_index]],
        '%s_with_%s.gif'%(area_name, band_index),
        width_pixels=1500,
        interval=interval,
        bands1=rgb_bands,
        bands2=[band_index],
        percentile_stretch1=percentile_stretch,
        percentile_stretch2=percentile_stretch,
        vmin_2=band_index_min,
        vmax_2=band_index_max,
        image_proc_func1=None,
        image_proc_func2=None,
        title1=combined_ds.time.values.astype(date_type),
        title2=band_index.upper(),
        show_date1=False,
        show_date2=False,
        animation_options={'repeat_delay': interval*2},
        annotation_kwargs1={},
        annotation_kwargs2={},
        onebandplot_cbar1=True,
        onebandplot_cbar2=True,
        onebandplot_kwargs1={},
        onebandplot_kwargs2={'cmap':band_index_cmap},
        shapefile_path1=None,
        shapefile_path2=None,
        shapefile_kwargs1={},
        shapefile_kwargs2={},
        time_dim1='time',
        x_dim1='x',
        y_dim1='y',
        time_dim2='time',
        x_dim2='x',
        y_dim2='y',
    )

In [None]:
from DEAPlotting import animated_timeseries
# Single-panel animation of the colour composite, one output pixel per data
# column. `date_type` is set in the preceding animation cell.
animated_timeseries(
    combined_ds[rgb_bands],
    '%s_color.gif'%(area_name),
    width_pixels=len(combined_ds.x),
    interval=interval,
    bands=rgb_bands,
    percentile_stretch=percentile_stretch,
    image_proc_func=None,
    title=combined_ds.time.values.astype(date_type),
    show_date=False,
    animation_options={'repeat_delay': interval*2},
    annotation_kwargs={},
    onebandplot_cbar=True,
    onebandplot_kwargs={},
    shapefile_path=None,
    shapefile_kwargs={},
    time_dim='time',
    x_dim='x',
    y_dim='y',
)