## COG

## Initialise COG

### Load packages

In [3]:
import os
import sys
import html
import requests
import gdal
import rasterio
import geopandas as gpd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from datetime import datetime
from lxml import etree
from tempfile import NamedTemporaryFile

sys.path.append('../../../Scripts')
from dea_dask import create_local_dask_cluster

sys.path.append('../../shared')
import satfetcher

### Set up a dask cluster

In [4]:
# start a local dask cluster using the helper imported from dea_dask;
# the client / cluster summary is rendered as this cell's output
create_local_dask_cluster()

0,1
Client  Scheduler: tcp://127.0.0.1:42199  Dashboard: /user/lewis/proxy/8787/status,Cluster  Workers: 1  Cores: 2  Memory: 13.11 GB


## Load study area polygon

In [None]:
# read the study area polygon. `gdf` is read by the STAC query cell further
# down (its bounds become the search bbox), so it has to be defined here for
# the notebook to run on a fresh kernel
gdf = gpd.read_file('yandisa.geojson')

## Set STAC Search parameters

In [None]:
# dea collection ids to search. todo get from user in arcgis, sentinel 2
collections = [
    'ga_ls5t_ard_3',
    'ga_ls7e_ard_3',
    'ga_ls8c_ard_3'
]

# set required bands
bands = [
    'oa_fmask',
    'nbart_blue',
    'nbart_green',
    'nbart_red',
    'nbart_nir',
    'nbart_swir_1',
    'nbart_swir_2'
]

# get satellite collection date range (inclusive). todo get from user in arcgis
start_dt, end_dt = '1990-01-01', '1995-12-31'

# validate the date strings and convert them to the utc timestamps used in the
# stac datetime interval; the end stamp is the last second of the end date so
# scenes acquired during that day are not cut off
start_dt = datetime.strptime(start_dt, "%Y-%m-%d").strftime("%Y-%m-%dT00:00:00Z")
end_dt = datetime.strptime(end_dt, "%Y-%m-%d").strftime("%Y-%m-%dT23:59:59Z")

# the stac bbox is expressed in lon/lat (wgs84): reproject the study area when
# it carries a crs, and use the extent of all features rather than the first row
gdf_wgs84 = gdf.to_crs(epsg=4326) if gdf.crs is not None else gdf
bbox = gdf_wgs84.total_bounds.tolist()

# bring it all together for a query
query = {
    'collections': collections,
    'datetime': '{0}/{1}'.format(start_dt, end_dt),
    'bbox': bbox,
    #'query': {'eo:cloud_cover': {'lt': 5}},
    'limit': 1000
}

## Fetch DEA Public Data via STAC Search

In [None]:
# set stac endpoint
search_endpoint = 'https://explorer.sandbox.dea.ga.gov.au/stac/search'

# post the query to the stac search endpoint; the timeout stops a stalled
# connection from hanging the kernel indefinitely
http_response = requests.post(search_endpoint, json=query, timeout=60)

# a non-2xx answer means the search itself failed, report the status with it
if not http_response.ok:
    raise ValueError(
        'DEA STAC SEARCH endpoint returned HTTP {0}.'.format(http_response.status_code))

# decode the json body; `stac_response` is the dict read by the later cells
# NOTE(review): only a single response page (up to 'limit' items) is read here;
# confirm whether the endpoint paginates larger result sets
stac_response = http_response.json()
num_items = len(stac_response.get('features') or [])
print('Found {0} satellite scenes in total.'.format(num_items))

## Iterate STAC response and remove cloud cover

In [None]:
# set max cloud cover (0 - 100)
max_cloud = 25

# all scenes returned by the stac search (empty list if the key is missing)
all_feats = stac_response.get('features') or []
num_all_items = len(all_feats)

# keep scenes whose reported cloud cover is below the threshold. scenes that
# carry no 'eo:cloud_cover' value are dropped as well, because their cloud
# cover cannot be checked (previously they crashed the loop via float(None))
feat_list = []
for feat in all_feats:
    cloud_cover = (feat.get('properties') or {}).get('eo:cloud_cover')
    if cloud_cover is not None and float(cloud_cover) < max_cloud:
        feat_list.append(feat)

# count cloud less scenes and compare; always reported, even when nothing is left
num_clean_items = len(feat_list)
print('Removed {0} satellite scenes due to clouds.'.format(num_all_items - num_clean_items))
print('Total of {0} satellite scenes remaining.'.format(num_clean_items))

## Set VRT Dataset and Raster generators

In [None]:
def build_vrt_dataset(x_size, y_size, srs, transform, axis_mapping='1,2'):
    """
    Build the root <VRTDataset> xml element of a VRT, holding the raster
    size, the spatial reference and the geotransform. Raster bands are not
    added here; the caller appends them as children.

    Parameters
    ----------
    x_size : int
        Raster width in pixels, written to the rasterXSize attribute.
    y_size : int
        Raster height in pixels, written to the rasterYSize attribute.
    srs : str
        Spatial reference text placed inside the <SRS> element.
    transform : str or sequence
        Geotransform placed inside the <GeoTransform> element. A string is
        written as given; any other sequence is joined with ', '.
    axis_mapping : str, optional
        Value of the dataAxisToSRSAxisMapping attribute on <SRS>.
        Defaults to '1,2', the value that used to be hard-coded.

    Returns
    -------
    etree element
        The <VRTDataset> element with <SRS> and <GeoTransform> children.
    """

    # create vrt dataset xml element
    xml_vrt = etree.Element('VRTDataset',
                            rasterXSize='{}'.format(x_size),
                            rasterYSize='{}'.format(y_size))

    # create srs xml element as first child
    xml_srs = etree.SubElement(xml_vrt, 'SRS',
                               dataAxisToSRSAxisMapping='{}'.format(axis_mapping))
    xml_srs.text = srs

    # accept the six coefficients as a sequence too, xml text must be a string
    if not isinstance(transform, str):
        transform = ', '.join(str(p) for p in transform)

    # create geotransform xml element as second child
    xml_transform = etree.SubElement(xml_vrt, 'GeoTransform')
    xml_transform.text = transform

    # todo: optional <Metadata> items (e.g. an MDI with key TIFFTAG_DATETIME)
    # are not written yet

    return xml_vrt

In [None]:
def build_vrt_raster(x_size, y_size, dtype, dt, url, relative_to, band_num, nodata,
                     source_band=None, block_size=512):
    """
    Build one <VRTRasterBand> xml element whose pixels come from a single
    remote raster, referenced through a <ComplexSource> child.

    Parameters
    ----------
    x_size : int
        Source raster width in pixels (used for SourceProperties and for
        the source / destination rectangles).
    y_size : int
        Source raster height in pixels.
    dtype : str
        Data type name written to the band's dataType attribute and to
        SourceProperties.
    dt : str
        Text stored in the band's <Description>; callers in this notebook
        pass the scene datetime so it can be recovered later.
    url : str
        Address of the raster; '/vsicurl/' is prefixed to it.
    relative_to : int
        Value of the relativeToVRT attribute on <SourceFilename>.
    band_num : int
        Band number of this band inside the VRT.
    nodata : int or float
        No-data value, written both at band level and source level.
    source_band : int, optional
        Band number to read from the source raster. Defaults to band_num,
        which keeps the previous behaviour. Pass 1 when every source file
        holds a single band.
    block_size : int, optional
        Value for BlockXSize / BlockYSize in SourceProperties. Default 512.

    Returns
    -------
    etree element
        The <VRTRasterBand> element with NoDataValue, Description and
        ComplexSource children, in that order.
    """

    # unless told otherwise read the same band number that is being written
    if source_band is None:
        source_band = band_num

    # create vrt raster band xml element
    xml_rast = etree.Element('VRTRasterBand',
                             dataType='{}'.format(dtype),
                             band='{}'.format(band_num))

    # band-level nodata and description (carries the datetime text)
    etree.SubElement(xml_rast, 'NoDataValue').text = '{}'.format(nodata)
    etree.SubElement(xml_rast, 'Description').text = '{}'.format(dt)

    # create complex source xml element as last child of the band
    xml_complex = etree.SubElement(xml_rast, 'ComplexSource')

    # source file, opened through /vsicurl/
    xml_filename = etree.SubElement(xml_complex, 'SourceFilename',
                                    relativeToVRT='{}'.format(relative_to))
    xml_filename.text = '/vsicurl/{}'.format(url)

    # band to read from the source file
    etree.SubElement(xml_complex, 'SourceBand').text = '{}'.format(source_band)

    # source properties: size, type and block size of the source raster
    etree.SubElement(xml_complex, 'SourceProperties',
                     rasterXSize='{}'.format(x_size),
                     rasterYSize='{}'.format(y_size),
                     DataType='{}'.format(dtype),
                     BlockXSize='{}'.format(block_size),
                     BlockYSize='{}'.format(block_size))

    # source and destination rectangles both cover the full raster from 0, 0
    rect = {
        'xOff': '{}'.format(0),
        'yOff': '{}'.format(0),
        'xSize': '{}'.format(x_size),
        'ySize': '{}'.format(y_size),
    }
    for rect_tag in ('SrcRect', 'DstRect'):
        etree.SubElement(xml_complex, rect_tag, **rect)

    # source-level nodata
    etree.SubElement(xml_complex, 'NODATA').text = '{}'.format(nodata)

    return xml_rast

## Fill VRT template with values

In [None]:
# todo checks, meta
def generate_vrt(feat_list, band=None):
    """
    Build one VRT xml string per STAC scene. Each VRT holds the scene-level
    size, srs and geotransform plus one raster band for every requested band
    that exists among the scene's assets.

    Parameters
    ----------
    feat_list : list
        STAC item dictionaries; each is read for its 'properties'
        ('datetime', 'proj:shape', 'proj:epsg', 'proj:transform') and its
        'assets' ('href', 'proj:shape').
    band : list, str
        Can be a list or string of name of band(s) required. When None no
        raster bands are added to the VRTs. 'oa_mask' is accepted as an
        alias of the 'oa_fmask' asset.

    Returns
    -------
    list of str
        One utf-8 decoded VRT xml string per scene, in feat_list order.

    Raises
    ------
    ValueError
        If an unsupported band is requested or feat_list is empty.
    TypeError
        If feat_list is not a list.
    """

    # check if band provided, if so and is str, make list
    if band is None:
        bands = []
    elif isinstance(band, list):
        bands = band
    else:
        bands = [band]

    # bands that may be requested
    allowed_bands = [
        'nbart_blue',
        'nbart_green',
        'nbart_red',
        'nbart_nir',
        'nbart_swir_1',
        'nbart_swir_2',
        'oa_fmask',
        'oa_mask'
    ]

    # names treated as the mask layer, and the asset key tried when the
    # requested name itself is not among a scene's assets
    mask_bands = ('oa_fmask', 'oa_mask')
    asset_aliases = {'oa_mask': 'oa_fmask'}

    # ensure requested bands allowed
    for b in bands:
        if b not in allowed_bands:
            raise ValueError('Requested an unsupported band.')

    # check features type, length
    if not isinstance(feat_list, list):
        raise TypeError('Features must be a list of xml objects.')
    elif not len(feat_list) > 0:
        raise ValueError('No features provided.')

    # set list vrt of each scene
    vrt_list = []

    # iter stac scenes, get metadata, insert bands into vrt template
    for feat in feat_list:

        # get scene properties and assets
        f_props = feat.get('properties')
        f_assets = feat.get('assets') or {}

        # get scene-level date
        f_dt = f_props.get('datetime')

        # get scene-level x, y parameters (proj:shape is read as rows, cols)
        f_shape = f_props.get('proj:shape')
        f_x_size = f_shape[1]
        f_y_size = f_shape[0]

        # get scene-level epsg src as wkt
        f_srs = rasterio.crs.CRS.from_epsg(f_props.get('proj:epsg')).wkt

        # get scene-level transform, reordered into gdal geotransform order
        aff = rasterio.transform.Affine(*f_props.get('proj:transform')[0:6])
        f_transform = ', '.join(str(p) for p in aff.to_gdal())

        # generate vrt dataset
        vrt_xml = build_vrt_dataset(x_size=f_x_size,
                                    y_size=f_y_size,
                                    srs=f_srs,
                                    transform=f_transform)

        # iterate bands and add to vrt if exists
        band_idx = 1
        for band_name in bands:

            # resolve the asset key, falling back to the alias
            if band_name in f_assets:
                asset_key = band_name
            else:
                asset_key = asset_aliases.get(band_name)
            if asset_key not in f_assets:
                continue

            # get asset
            asset = f_assets[asset_key]

            # set dtype and nodata: int16 / -999 unless mask
            # NOTE(review): the mask is written as Byte with nodata 0 on the
            # assumption that it is an 8-bit layer whose 0 class means
            # no data - confirm against the product definition
            is_mask = band_name in mask_bands
            a_dtype = 'Byte' if is_mask else 'Int16'
            a_nodata = 0 if is_mask else -999

            # get asset raster x, y sizes, falling back to the scene shape
            a_shape = asset.get('proj:shape') or f_shape
            a_x_size = a_shape[1]
            a_y_size = a_shape[0]

            # get raster url, replace s3 with https
            a_url = asset.get('href')
            a_url = a_url.replace('s3://dea-public-data', 'https://data.dea.ga.gov.au')

            # build raster xml
            rast_xml = build_vrt_raster(x_size=a_x_size,
                                        y_size=a_y_size,
                                        dtype=a_dtype,
                                        dt=f_dt,
                                        url=a_url,
                                        relative_to=1,
                                        band_num=band_idx,
                                        nodata=a_nodata)

            # add raster xml to vrt dataset xml as child
            vrt_xml.append(rast_xml)

            # increase band index
            band_idx += 1

        # decode to utf-8 string and append to vrt list
        vrt_str = etree.tostring(vrt_xml).decode('utf-8')
        vrt_list.append(vrt_str)

    return vrt_list

In [None]:
# create list of bands needed (each band listed once)
# NOTE(review): the search cell's `bands` list names the mask 'oa_fmask' while
# this list uses 'oa_mask' - confirm which key the STAC assets expose
wanted_bands = [
    'nbart_blue',
    'nbart_green',
    'nbart_red',
    'nbart_nir',
    'nbart_swir_1',
    'nbart_swir_2',
    'oa_mask'
]

# build the per-scene vrts for every wanted band in one pass
vrts_by_band = {
    band_name: generate_vrt(feat_list=feat_list, band=band_name)
    for band_name in wanted_bands
}

# keep the individual names that the later cells read
vrt_blue = vrts_by_band['nbart_blue']
vrt_green = vrts_by_band['nbart_green']
vrt_red = vrts_by_band['nbart_red']
vrt_nir = vrts_by_band['nbart_nir']
vrt_swir_1 = vrts_by_band['nbart_swir_1']
vrt_swir_2 = vrts_by_band['nbart_swir_2']
vrt_mask = vrts_by_band['oa_mask']

## Build a completed in-memory VRT file

In [None]:
# checks, meta
def create_vrt_file(vrt_files):
    """
    Combine several VRT sources into one VRT with gdal.BuildVRT and return
    the resulting VRT xml as text. The VRT is written to a named temporary
    file, read back and the file is removed when the function returns.

    Parameters
    ----------
    vrt_files : list
        Sources handed to gdal.BuildVRT; in this notebook these are the
        per-scene VRT xml strings of one band.

    Returns
    -------
    str
        The utf-8 decoded content of the combined VRT.

    Raises
    ------
    ValueError
        If no sources are provided.
    RuntimeError
        If gdal.BuildVRT returns no dataset.
    """

    # nothing to combine, fail early instead of returning an empty string
    if not vrt_files:
        raise ValueError('No VRT sources provided.')

    # load up a temp named file and create vrt
    with NamedTemporaryFile() as tmp:

        # set vrt options: every source becomes its own band. the other
        # options (outputBounds, resampleAlg, resolution, xRes / yRes,
        # outputSRS, targetAlignedPixels) are left at their defaults
        vrt_opts = gdal.BuildVRTOptions(separate=True)

        # build vrt and make sure gdal actually produced a dataset
        vrt_out = gdal.BuildVRT(tmp.name, vrt_files, options=vrt_opts)
        if vrt_out is None:
            raise RuntimeError('GDAL could not build a VRT from the provided sources.')

        # drop the reference to close the dataset (to create the file)
        vrt_out = None

        # warp and translate funcs
        # todo: MAY NEED

        # read it in to memory from the start and decode it
        tmp.seek(0)
        return tmp.read().decode("utf-8")

In [None]:
# example bounding box for a small subset of the raster (utm 50N), unused
#bb = [683100.0, -2542470.0, 686070.0, -2539500.0]

# todo improve this code

# combine the per-scene vrts of each band into one vrt string per band,
# working through the bands in a fixed order
(vrt_blue_out,
 vrt_green_out,
 vrt_red_out,
 vrt_nir_out,
 vrt_swir_1_out,
 vrt_swir_2_out,
 vrt_mask_out) = [
    create_vrt_file(vrt_files=scene_vrts)
    for scene_vrts in (vrt_blue, vrt_green, vrt_red, vrt_nir,
                       vrt_swir_1, vrt_swir_2, vrt_mask)
]

## Parse datetime strings into map

In [None]:
def parse_datetimes(vrt_string):
    """
    Collect the text of every <Description> element found in a VRT xml
    string, keyed by its 1-based position in document order.

    Parameters
    ----------
    vrt_string : str
        VRT xml text. Escaped markup inside it (e.g. nested VRTs stored
        as &lt;...&gt; text) is unescaped first so that Description
        elements inside the nested xml are found as well.

    Returns
    -------
    dict
        Maps 1, 2, 3, ... to the text of each Description element.
    """

    # convert html entities back to markup, using the string that was passed
    # in (not a notebook-level variable)
    clean_elem = html.unescape(vrt_string)

    # convert string to etree elements
    root = etree.fromstring(clean_elem)

    # pull descriptions out at any depth and number them from 1
    return {
        idx: elem.text
        for idx, elem in enumerate(root.findall('.//Description'), start=1)
    }

In [None]:
# pull the band-index -> datetime map out of each band's combined vrt,
# working through the bands in a fixed order
(dt_blue,
 dt_green,
 dt_red,
 dt_nir,
 dt_swir_1,
 dt_swir_2,
 dt_mask) = [
    parse_datetimes(vrt_string=band_vrt)
    for band_vrt in (vrt_blue_out, vrt_green_out, vrt_red_out, vrt_nir_out,
                     vrt_swir_1_out, vrt_swir_2_out, vrt_mask_out)
]

# check if lengths are all same


## Convert to chunked dataset

In [None]:
def build_xr_dataset(vrt_file, band_name, x_slice=slice(4000, 5000), y_slice=slice(3000, 4000)):
    """
    Open a VRT as a chunked xarray dataset with a single data variable.
    The 'band' dimension is renamed to 'time' and the result is cut down
    to a pixel window.

    Parameters
    ----------
    vrt_file : str
        VRT passed to xr.open_rasterio (this notebook passes the VRT xml
        text itself).
    band_name : str
        Name given to the data variable of the returned dataset.
    x_slice : slice, optional
        Positional window along x. Defaults to slice(4000, 5000), the
        window that used to be hard-coded.
    y_slice : slice, optional
        Positional window along y. Defaults to slice(3000, 4000).

    Returns
    -------
    xarray dataset
        Dataset with variable `band_name`, subset by x_slice / y_slice.
    """

    # setup chunks: one band per chunk, x and y sized automatically
    chunks = {'band': 1, 'x': 'auto', 'y': 'auto'}

    # load xr as data array
    # NOTE(review): xr.open_rasterio was deprecated in later xarray releases;
    # confirm it is available in the installed version
    ds = xr.open_rasterio(vrt_file, chunks=chunks)

    # rename default band label to time
    ds = ds.rename({'band': 'time'})

    # convert to dataset
    ds = ds.to_dataset(name=band_name, promote_attrs=True)

    # subset by pixel position. todo: derive the window from a bounding box
    ds = ds.isel(x=x_slice, y=y_slice)

    return ds

In [None]:
# create the chunked dataset for the blue band from its combined vrt;
# only this band is converted here
ds_blue = build_xr_dataset(vrt_file=vrt_blue_out, band_name='nbart_blue')

In [None]:
# replace datetime
def replace_datetimes(ds, dt):
    """
    Replace the integer labels of the 'time' coordinate with datetimes and
    return the dataset sorted by time.

    Parameters
    ----------
    ds : xarray dataset
        Dataset whose 'time' values are the keys of `dt`. Its 'time'
        coordinate is overwritten in place before sorting.
    dt : dict
        Maps each 'time' label to a datetime string; a 'Z' suffix is
        removed before conversion.

    Returns
    -------
    xarray dataset
        Result of ds.sortby('time') after the labels were replaced.

    Raises
    ------
    KeyError
        If a 'time' label of `ds` has no entry in `dt`.
    """

    # drop the 'Z' designator and convert to numpy datetimes, using the map
    # that was passed in (not a notebook-level variable)
    dt_dict = {label: np.datetime64(stamp.replace('Z', '')) for label, stamp in dt.items()}

    # remap
    ds['time'] = [dt_dict[label] for label in ds['time'].values.tolist()]
    return ds.sortby('time')
    
# swap the blue dataset's integer time labels for the datetimes parsed from
# its vrt; ds_blue is re-assigned, so this cell is not meant to be run twice
ds_blue = replace_datetimes(ds_blue, dt_blue)

In [None]:
# compute the blue dataset and time it; ds_blue is re-assigned to the
# computed result, then displayed as the cell output
%time ds_blue = ds_blue.compute()
ds_blue

In [None]:
# combine all dask datasets into one
#xr.merge([ds_blue, ds_green, ds_red])

## Test download times

### Try raw, without dask

In [None]:
# speed testing: time a plain .compute() with the default scheduler
# NOTE(review): `ds` is not assigned in any earlier visible cell (only
# ds_blue is) - confirm which dataset is meant to be timed here
%time ds = ds.compute()

### Try raw, without dask but with threading

In [None]:
# speed testing: time .compute() with the threaded scheduler requested
# NOTE(review): `ds` is re-assigned by every timing cell, so later timings
# may run on an already computed dataset - confirm
%time ds = ds.compute(scheduler='threads')

### Try raw, without dask but with processes

In [None]:
# speed testing: time .compute() with the process-based scheduler requested
%time ds = ds.compute(scheduler='processes')

### Try dask, with distributed scheduler

In [None]:
# start a distributed client backed by worker processes and display it
# NOTE(review): a local cluster was already started near the top of the
# notebook via create_local_dask_cluster() - confirm a second client is wanted
import dask
from dask.distributed import Client
client = Client(processes=True)
client

In [None]:
# time .compute() once the distributed client above exists
# author's measurement: about 47 secs with processes=false, 21 secs when True
%time ds = ds.compute()

### Try dask data arrays split and futures used

In [None]:
import concurrent.futures 

# create compute func
def compute_da(da):
    """Call .compute() on `da` and hand back whatever it returns."""
    computed = da.compute()
    return computed

In [None]:
# split ds into separate selections, one per value of the time coordinate
da_list = []
for dt in ds['time']:
    da = ds.sel(time=dt)
    da_list.append(da)
    
# compute all time slices in parallel on a thread pool and time it;
# executor.map keeps the results in the same order as da_list
num_cores = 2
with concurrent.futures.ThreadPoolExecutor(num_cores) as executor:
    %time da_list = list(executor.map(compute_da, da_list))
    
# stitch the computed slices back together along time
ds = xr.concat(da_list, dim='time')

## Working

## Use this to auto gen vrt to test

In [None]:
urls = [
    '/vsicurl/https://data.dea.ga.gov.au/baseline/ga_ls5t_ard_3/112/076/1990/02/09/ga_ls5t_nbart_3-0-0_112076_1990-02-09_final_band02.tif',
    '/vsicurl/https://data.dea.ga.gov.au/baseline/ga_ls5t_ard_3/112/076/1990/02/09/ga_ls5t_nbart_3-0-0_112076_1990-02-09_final_band03.tif',
    '/vsicurl/https://data.dea.ga.gov.au/baseline/ga_ls5t_ard_3/112/076/1990/02/09/ga_ls5t_nbart_3-0-0_112076_1990-02-09_final_band04.tif',
    '/vsicurl/https://data.dea.ga.gov.au/baseline/ga_ls5t_ard_3/112/076/1990/02/09/ga_ls5t_nbart_3-0-0_112076_1990-02-09_final_band05.tif',
    '/vsicurl/https://data.dea.ga.gov.au/baseline/ga_ls5t_ard_3/112/076/1990/02/09/ga_ls5t_nbart_3-0-0_112076_1990-02-09_final_band06.tif'
]

# let gdal generate a reference vrt (one band per url) into the temp file and
# read it back as text. the vrt is written to tmp.name so that the read below
# sees it; the source list is passed once and the call is closed properly
with NamedTemporaryFile() as tmp:
    out = gdal.BuildVRT(tmp.name,
                        urls,
                        #xRes=10.0,
                        #yRes=10.0,
                        #outputSRS=rasterio.crs.CRS.from_epsg(3577).wkt,
                        #outputBounds=bb,
                        #resolution='highest',
                        #resampleAlg='near',
                        separate=True)

    # drop the reference to close the dataset (to create the file)
    out = None

    tmp.seek(0)
    v = tmp.read().decode("utf-8")

#xr.open_rasterio(v).to_dataset(dim='band')