In [None]:
# default_exp data

# Data
> This module contains functions to download and preprocess the data

In [None]:
#hide
from nbdev.export import notebook2script

In [None]:
#export
import json
import os
import time
import zipfile
from pathlib import Path

import ee
import numpy as np
import pandas as pd
import rasterio
import requests
from IPython.core.debugger import set_trace

from banet.geo import open_tif, merge, Region
from banet.geo import downsample

In [None]:
#export
class Region(Region):
    "Extends the `Region` imported from `banet.geo` with a `transform` property."
    @property
    def transform(self):
        "Rasterio Affine transform of the region"
        # Origin is taken from the first and last entries of the bounding box;
        # the same pixel size is used along both axes.
        origin_x, origin_y = self.bbox[0], self.bbox[-1]
        step = self.pixel_size
        return rasterio.transform.from_origin(origin_x, origin_y, step, step)
    
class RegionST(Region):
    "Defines a region in space and time with a name, a bounding box and the pixel size."
    def __init__(self, name:str, bbox:list, pixel_size:float, time_start:str=None,
                 time_end:str=None, time_freq:str='D', time_margin:int=0):
        """Store the spatial and temporal definition of the region.

        name: identifier of the region.
        bbox: four values unpacked into a rasterio `BoundingBox` (left, bottom, right, top).
        pixel_size: size of a pixel, in the same units as `bbox`.
        time_start, time_end: anything `pd.Timestamp` can parse once converted to `str`;
            `None` is stored as `pd.NaT`.
        time_freq: frequency string handed to `pd.date_range` by `times`.
        time_margin: number of days added on both sides of the period by `times`.
        """
        self.name = name
        self.bbox = rasterio.coords.BoundingBox(*bbox) # left, bottom, right, top
        self.pixel_size = pixel_size
        # A missing bound is kept as NaT instead of being stringified to 'None'.
        self.time_start = pd.NaT if time_start is None else pd.Timestamp(str(time_start))
        self.time_end = pd.NaT if time_end is None else pd.Timestamp(str(time_end))
        self.time_margin = time_margin
        self.time_freq = time_freq

    @property
    def times(self):
        "Property that computes the date_range for the region."
        margin = pd.Timedelta(days=self.time_margin)
        return pd.date_range(self.time_start - margin, self.time_end + margin,
                             freq=self.time_freq)

    @classmethod
    def load(cls, file, time_start=None, time_end=None):
        """Loads region information from json file.

        The file must provide `name`, `bbox` and `pixel_size`. `time_start` and
        `time_end` are read from the file only when not passed explicitly (a
        `KeyError` is raised if they are then missing). `time_freq` and
        `time_margin` are taken from the file when present, otherwise the
        constructor defaults ('D' and 0) are used.
        """
        with open(file, 'r') as f:
            args = json.load(f)
        if time_start is None:
            time_start = args['time_start']
        if time_end is None:
            time_end = args['time_end']
        return cls(args['name'], args['bbox'], args['pixel_size'],
                   time_start, time_end,
                   time_freq=args.get('time_freq', 'D'),
                   time_margin=args.get('time_margin', 0))
    
def extract_region(df_row, cls=Region):
    """Create Region object from a row of the metadata dataframe.

    df_row: object exposing `event_id`, `bbox` and `pixel_size` attributes, plus
        `time_start` and `time_end` when `cls` is a `RegionST` subclass.
    cls: class to instantiate; must be `Region`, `RegionST` or a subclass.

    Raises `NotImplementedError` for any other class.
    """
    # RegionST is itself a Region subclass, so it has to be tested first.
    if issubclass(cls, RegionST):
        return cls(df_row.event_id, df_row.bbox, df_row.pixel_size, 
                   df_row.time_start, df_row.time_end)
    if issubclass(cls, Region):
        return cls(df_row.event_id, df_row.bbox, df_row.pixel_size)
    # `NotImplemented` is a comparison sentinel, not an exception: calling it fails.
    raise NotImplementedError('cls must be one of the following [Region, RegionST]')

In [None]:
#export
def coords2bbox(lon, lat, pixel_size):
    "Bounding box `[lon.min(), lat.min(), lon.max()+pixel_size, lat.max()+pixel_size]` of the coordinates."
    lower = [lon.min(), lat.min()]
    # The upper edges are pushed out by one pixel beyond the largest coordinate.
    upper = [lon.max() + pixel_size, lat.max() + pixel_size]
    return lower + upper

def split_region(region:RegionST, size:int, cls=Region):
    """Split `region` into tiles of at most `size` x `size` pixels.

    region: region exposing `coords()` (returning the lon and lat coordinate
        arrays, reshaped here as 1D arrays) and `pixel_size`; the `time_*`
        attributes are also read when `cls` is a `RegionST` subclass.
    size: maximum number of coordinates per tile along each axis.
    cls: class of the returned tiles; `Region`, `RegionST` or a subclass.

    Returns a list of `cls` instances with an empty name, ordered with lat
    varying fastest. Raises `NotImplementedError` for any other `cls`.
    """
    lon, lat = region.coords()

    def _chunks(x):
        "Consecutive pieces of `x` of length `size`, plus the non-empty remainder."
        n_full = (len(x)//size)*size
        pieces = list(x[:n_full].reshape(-1, size))
        # Keep the remainder only when it exists: an empty piece has no min/max,
        # so it would make `coords2bbox` fail.
        if n_full < len(x): pieces.append(x[n_full:])
        return pieces

    lons, lats = _chunks(lon), _chunks(lat)
    if issubclass(cls, RegionST):
        return [cls('', coords2bbox(ilon, ilat, region.pixel_size), 
                    pixel_size=region.pixel_size, time_start=region.time_start,
                    time_end=region.time_end, time_freq=region.time_freq,
                    time_margin=region.time_margin) for ilon in lons for ilat in lats]
    if issubclass(cls, Region):
        return [cls('', coords2bbox(ilon, ilat, region.pixel_size), pixel_size=region.pixel_size) 
                for ilon in lons for ilat in lats]
    # `NotImplemented` is a comparison sentinel, not an exception: calling it fails.
    raise NotImplementedError('cls must be one of the following [Region, RegionST]')
            
def merge_tifs(files:list, fname:str, delete=False):
    """Merge several tif files into a single-band, lzw-compressed tif.

    files: paths of the tifs to merge; the output is written next to the first one
        and inherits its profile (with updated size and transform).
    fname: file name of the merged tif.
    delete: when True the input files are removed after the merged file is written.

    NOTE(review): assumes `open_tif` returns rasterio dataset objects (exposing
    `.profile` and `.close()`) and that the merged array squeezes to 2D.
    """
    sources = [open_tif(str(f)) for f in files]
    try:
        data, tfm = merge(sources)
        profile = sources[0].profile
    finally:
        # Release the file handles before writing/deleting; open handles can
        # prevent `os.remove` from succeeding on some platforms.
        for src in sources: src.close()
    data = data.squeeze()
    fname = Path(files[0]).parent/fname
    with rasterio.Env():
        height, width = data.shape
        profile.update(width=width, height=height, transform=tfm, compress='lzw')
        with rasterio.open(str(fname), 'w', **profile) as dst:
            dst.write(data, 1)
    if delete:
        for f in files: os.remove(f)

In [None]:
#export
def filter_region(image_collection:ee.ImageCollection, region:RegionST, times:tuple, bands=None):
    """Restrict `image_collection` to a time window and to the rectangle of `region`.

    times: indexable whose first two items are passed to `filterDate`.
    bands: optional band selection passed to `select`; skipped when None.
    """
    by_date = image_collection.filterDate(times[0], times[1])
    by_area = by_date.filterBounds(ee.Geometry.Rectangle(region.bbox))
    return by_area if bands is None else by_area.select(bands)

def _download_tile(image, region, scale, path_save, n_tries=10, wait=10):
    "Download `image` over `region` as a zip into `path_save`, extract it and return the extracted file names."
    zip_path = path_save/'data.zip'
    url, error = None, None
    for attempt in range(n_tries):
        try:
            url = image.getDownloadURL(
                {'scale': scale, 'crs': 'EPSG:4326', 'region': f'{region}'})
            r = requests.get(url)
            r.raise_for_status()
            with open(str(zip_path), 'wb') as f:
                f.write(r.content)
            with zipfile.ZipFile(str(zip_path), 'r') as f:
                files = f.namelist()
                f.extractall(str(path_save))
            return files
        except Exception as e:
            # Remember the failure and retry after a pause (not after the last attempt).
            error = e
            if attempt < n_tries - 1: time.sleep(wait)
        finally:
            # The zip may not exist if the failure happened before it was written.
            if zip_path.is_file(): os.remove(str(zip_path))
    raise Exception(f'Failed to process {url}') from error

def download_data(R:RegionST, times, products, bands, path_save, scale=10):
    """Download a median composite of `products` over region `R` as one tif per band.

    R: region to download; split in tiles of 32 coordinates when `R.shape[0] > 32`.
    times: time window; its first two items are used to filter the images.
    products: Earth Engine image collection ids, merged into a single collection.
    bands: band names; the result is one `download.{band}.tif` file per band.
    path_save: output directory (a `Path`, created if needed). It must provide the
        `ls(include=...)` method, which is not defined in this module.
    scale: value passed as `scale` to `getDownloadURL`.

    Nothing is downloaded when the merged file of every band already exists, and
    tiles whose files are already on disk are skipped. Raises `ValueError` when
    `bands` is empty and `Exception` when a tile cannot be downloaded.
    """
    if len(bands) == 0: raise ValueError('bands must contain at least one band name')
    ee.Initialize()
    path_save.mkdir(exist_ok=True, parents=True)
    if all((path_save/f'download.{b}.tif').is_file() for b in bands): return

    tiles = [R] if R.shape[0] <= 32 else split_region(R, size=32, cls=RegionST)
    for j, tile in enumerate(tiles):
        if all((path_save/f'download.{b}_{j}.tif').is_file() for b in bands): continue
        region = (f"[[{tile.bbox.left}, {tile.bbox.bottom}], [{tile.bbox.right}, {tile.bbox.bottom}], " +
                  f"[{tile.bbox.right}, {tile.bbox.top}], [{tile.bbox.left}, {tile.bbox.top}]]")
        # Merge products to single image collection
        imCol = ee.ImageCollection(products[0])
        for product in products[1:]:
            imCol = imCol.merge(ee.ImageCollection(product))
        image = filter_region(imCol, tile, times=times, bands=bands).median()
        # Tag the extracted files with the tile index so tiles do not overwrite each other.
        for f in _download_tile(image, region, scale, path_save):
            f = path_save/f
            os.rename(str(f), str(path_save/f'{f.stem}_{j}{f.suffix}'))

    # Merge files
    suffix = '.tif'
    max_merge = 500 # maximum number of tiles merged in a single call
    files = [o.stem for o in path_save.ls(include=[suffix])]
    # Tile files end in `_{index}`; the length test on the last `_` field is
    # what tells them apart from other file names.
    tiled = [o for o in files if len(o.split('_')[-1]) < 6]
    ref = np.unique(['_'.join(o.split('_')[:-1]) for o in tiled])
    ids = np.unique([int(o.split('_')[-1]) for o in tiled])
    file_groups = [[path_save/f'{r}_{i}{suffix}' for i in ids 
                    if f'{r}_{i}' in files] for r in ref] 
    for fs in file_groups:
        if len(fs) == 0: continue
        if len(fs) < max_merge:
            fsave = '_'.join(fs[0].stem.split('_')[:-1]) + suffix
            merge_tifs(fs, fsave, delete=True)
        else:
            # Too many tiles: merge them in consecutive batches saved as `_break{n}` files.
            fs_break = [fs[k:k+max_merge] for k in range(0, len(fs), max_merge)]
            for fsi, fs2 in enumerate(fs_break):
                fsave = '_'.join(fs2[0].stem.split('_')[:-1]) + f'_break{fsi}' + suffix
                merge_tifs(fs2, fsave, delete=True)

    # Second pass: merge the `_break{n}` batches (if any) into the final files.
    files = [o.stem for o in path_save.ls(include=[suffix, '_break'])]
    partial = [o for o in files if len(o.split('_')[-1]) < 11]
    ref = np.unique(['_'.join(o.split('_')[:-1]) for o in partial])
    ids = np.unique([o.split('_')[-1] for o in partial])
    file_groups = [[path_save/f'{r}_{i}{suffix}' for i in ids 
                    if f'{r}_{i}' in files] for r in ref] 
    for fs in file_groups:
        # Stems without `_` produce an empty group; there is nothing to merge then.
        if len(fs) == 0: continue
        fsave = '_'.join(fs[0].stem.split('_')[:-1]) + suffix
        merge_tifs(fs, fsave, delete=True)

Example: downloading data for an arbitrary region:
```python
R = RegionST('test_region', [-9.0,39.95,-8.9,40.05], 0.001, time_start='2020-07-01', time_end='2020-07-15')
R.time_margin=1
products = ["COPERNICUS/S2"]
bands = ['B4', 'B8', 'B12']
path  = Path('temp')
before = (R.times[0]-pd.Timedelta(days=60), R.times[0])
after  = (R.times[-1], R.times[-1]+pd.Timedelta(days=60))
for mode, time_window in zip(['before', 'after'], [before, after]):
    path_save = path/R.name/mode
    download_data(R, time_window, products, bands, path_save)
```

In [None]:
#export
def _read_composite(folder, bands, scale_factor):
    "Read `download.{band}.tif` from `folder` for each band, in order; returns the scaled float16 stack and the transform and crs of the first band."
    layers, transform, crs = [], None, None
    for b in bands:
        with rasterio.open(str(folder/f'download.{b}.tif')) as src:
            if transform is None: transform, crs = src.transform, src.crs
            layers.append(src.read())
    return np.concatenate(layers).astype(np.float16)*scale_factor, transform, crs

def get_event_data(event_id, year, coarse_mask_file, path=Path('.'), 
                   coarse_mask_doy_layer=1, products=['COPERNICUS/S2'], 
                   bands=['B4', 'B8', 'B12'], scale_factor=1e-4, composite_days=[60,60]):
    """Download the before/after composites of an event and stack them with the coarse mask.

    event_id: name of the event; data is stored under `path/event_id/{before,after}`.
    year: year combined with the day-of-year values of the mask to date the event.
    coarse_mask_file: tif whose layer `coarse_mask_doy_layer` holds day-of-year values;
        its minimum and maximum define the start and end of the event.
    products, bands: passed to `download_data`; the bands are stacked in the given order.
    scale_factor: multiplier applied to the downloaded rasters.
    composite_days: number of days of the time windows before and after the event.

    Returns `(im, transform, crs)`: `im` has the raster layers of the 'before'
    files, then those of the 'after' files, then the resampled mask, stacked along
    the last axis; `transform` and `crs` are those of the first 'before' band.

    NOTE(review): assumes `open_tif` returns a rasterio dataset usable as a
    context manager.
    """
    # Read the mask layer once and release the file handle right away.
    with open_tif(coarse_mask_file) as src:
        ba100 = src.read(coarse_mask_doy_layer)
        bounds, pixel_size = list(src.bounds), src.transform[0]
    doy_start, doy_end = ba100.min(), ba100.max()
    time_start = pd.Timestamp(f'{year}-01-01') + pd.Timedelta(days=doy_start-1)
    time_end   = pd.Timestamp(f'{year}-01-01') + pd.Timedelta(days=doy_end-1)
    R = RegionST(event_id, bounds, pixel_size, 
                 time_start=time_start, time_end=time_end, time_margin=1)
    before = (R.times[0]-pd.Timedelta(days=composite_days[0]), R.times[0])
    after  = (R.times[-1], R.times[-1]+pd.Timedelta(days=composite_days[1]))
    for mode, time_window in zip(['before', 'after'], [before, after]):
        download_data(R, time_window, products, bands, path/R.name/mode)

    # Files are addressed by band name (the names `download_data` writes) so the
    # channel order follows `bands` instead of the order of a directory listing.
    rst_s10before, transform, crs = _read_composite(path/R.name/'before', bands, scale_factor)
    rst_s10after, _, _ = _read_composite(path/R.name/'after', bands, scale_factor)
    rst_ba100 = downsample(ba100, src_tfm=R.transform, dst_tfm=transform, 
                           dst_shape=(1, *rst_s10before.shape[-2:]), resampling='bilinear').astype(np.float32)
    im = np.concatenate([rst_s10before, rst_s10after, rst_ba100], axis=0).transpose(1,2,0)
    return im, transform, crs

In [None]:
%%time
# Example run for event 'temp' (year 2020) using the coarse mask 'temp/banet100m.tif'.
# `get_event_data` calls `download_data`, which initializes Earth Engine.
im, transform, crs = get_event_data('temp', 2020, 'temp/banet100m.tif')
# Show the stacked image shape together with its georeferencing.
im.shape, transform, crs

CPU times: user 449 ms, sys: 44.8 ms, total: 493 ms
Wall time: 4.84 s


((892, 881, 7),
 Affine(8.983152841195215e-05, 0.0, 6.579979793118671,
        0.0, -8.983152841195215e-05, 43.21507371008099),
 CRS.from_epsg(4326))

In [None]:
#hide 
# Run nbdev's notebook export; the output below lists each notebook it converted.
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.ipynb.
Converted 02_models.ipynb.
Converted 04_predict.ipynb.
Converted 05_cli.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
