In [1]:
# Jupyter notebook related
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import geopandas as gpd
from dask import delayed

# from elogs import Elogs, ElogsTask

with open('../../../connstr_vegteam') as f:
    connect_str = f.read()
container_name = 'evoland'

locs_fn = "../../../locations_v2.csv"

In [3]:
from importlib.resources import files

In [4]:
import satio_pc.layers

In [5]:
satio_pc.__version__

'0.0.2'

In [9]:
s2grid = satio_pc.layers.load('s2grid_all')

In [10]:
s2grid.shape

(56686, 4)

# Cluster setup

In [3]:
# from dask_gateway import Gateway
# gateway = Gateway()


# # List the clusters and get the cluster report
# clusters_reports = gateway.list_clusters()

# # Get the first cluster report
# cluster_report = clusters_reports[0]

# # Connect to the cluster using the cluster report
# cluster = gateway.connect(cluster_report)

# # Get the client object from the cluster
# client = cluster.get_client()

In [4]:
# stop clusters
from dask_gateway import Gateway
gateway = Gateway()
clusters_reports = gateway.list_clusters()

clusters = [gateway.stop_cluster(c.name) for c in clusters_reports]
clusters

[]

In [5]:
# create and scale cluster
from dask.distributed import PipInstall, Client
import dask_gateway

cluster = dask_gateway.GatewayCluster()
client = cluster.get_client()

print(client.dashboard_link)

cluster.scale(100)

cluster

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.0ffc7bb799164d69bcecffd6b0463d46/status


VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [13]:
# Once cluster is scaled, install satio_pc
satio_pc_url = "https://s3-eu-central-1.amazonaws.com/vito-worldcover-public/wheels/satio_pc-0.0.1-py3-none-any.whl"
plugin = PipInstall(packages=[satio_pc_url])
client.register_worker_plugin(plugin)

{'tls://10.244.119.128:35439': {'status': 'OK'},
 'tls://10.244.119.129:33653': {'status': 'OK'},
 'tls://10.244.119.130:45755': {'status': 'OK'},
 'tls://10.244.123.13:38377': {'status': 'OK'},
 'tls://10.244.123.14:34383': {'status': 'OK'},
 'tls://10.244.123.15:39137': {'status': 'OK'},
 'tls://10.244.123.16:33829': {'status': 'OK'},
 'tls://10.244.123.17:36433': {'status': 'OK'},
 'tls://10.244.123.18:40811': {'status': 'OK'},
 'tls://10.244.129.71:38339': {'status': 'OK'},
 'tls://10.244.129.72:34239': {'status': 'OK'},
 'tls://10.244.129.73:46675': {'status': 'OK'},
 'tls://10.244.129.74:45393': {'status': 'OK'},
 'tls://10.244.129.75:40117': {'status': 'OK'},
 'tls://10.244.129.76:43925': {'status': 'OK'},
 'tls://10.244.143.18:39145': {'status': 'OK'},
 'tls://10.244.143.19:44651': {'status': 'OK'},
 'tls://10.244.143.20:37485': {'status': 'OK'},
 'tls://10.244.143.21:34957': {'status': 'OK'},
 'tls://10.244.143.22:40501': {'status': 'OK'},
 'tls://10.244.143.23:45767': {'statu

In [None]:
# check logs
logs = client.get_worker_logs()

print(len(logs))

for worker, worker_logs in logs.items():
    print(f"Logs for worker {worker}:")
    for log in worker_logs:
        print(log)
    print()
    print('*'*100)

In [15]:
# shutdown cluster
# cluster.shutdown()

# Training data extraction

In [6]:
import os
from pathlib import Path

import numpy as np
import pandas as pd


settings = {
        "max_cloud_cover": 90,
        "composite": {"freq": 10,
                      "window": 20,
                      "mode": "median"},
        "mask": {"erode_r": 3,
                 "dilate_r": 13,
                 "snow_dilate_r": 3,
                 "max_invalid_ratio": 1,
                 "max_invalid_snow_cover": 0.9},
        "scl_valid_th": 0.1,
        "bands": ['B02', 'B03', 'B04', 'B08', 'B11', 'B12'],
        "indices": ["ndvi"],
        "percentiles": [10, 25, 50, 75, 90],
    }


def read_block(block):
    from satio_pc.reader import S2TileReader
    tile = block.tile
    epsg = block.epsg
    # bounds = loc.xmin, loc.ymin, loc.xmax, loc.ymax
    bounds = block.bounds

    bands = ['B08']

    reader = S2TileReader(tile,
                          start_date,
                          end_date,
                          max_cloud_cover)
    darr = reader.read(bounds, epsg, bands, max_workers=20)
    return darr


def read_bounds(tile,
                bounds,
                epsg,
                bands,
                start_date,
                end_date,
                max_cloud_cover,
                max_workers=20):
    from satio_pc.reader import S2TileReader
    reader = S2TileReader(tile,
                          start_date,
                          end_date,
                          max_cloud_cover)
    darr = reader.read(bounds, epsg, bands, max_workers=max_workers)
    return darr


In [15]:
patch_id, tile, epsg, xmin, ymin, xmax, ymax, year = args[0].tolist()

In [22]:
year = int(year)
epsg = int(epsg)
xmin = float(xmin)
ymin = float(ymin)
xmax = float(xmax)
ymax = float(ymax)

In [23]:
from loguru import logger
from satio_pc.reader import S2TileReader
from satio_pc.preprocessing.clouds import preprocess_scl
from satio_pc.sentinel2 import BANDS_RESOLUTION
from satio_pc.geotiff import slash_tile
from satio_pc.utils.azure import AzureBlobReader

azure = AzureBlobReader(connect_str,
                        container_name)

bounds = xmin, ymin, xmax, ymax

fn = f'evotrain_v2_{year}_{patch_id}.tif'
dst_fn = f"evotrain/v2/{year}/{slash_tile(tile)}/{fn}"
if azure.check_file_exists(dst_fn):
    logger.warning(f"Target {dst_fn} exists, skipping...")

start_date = f'{year}-01-01'
end_date = f'{year + 1}-01-01'

max_cloud_cover = settings['max_cloud_cover']

reader = S2TileReader(tile,
                      start_date,
                      end_date,
                      max_cloud_cover)

In [7]:
def _extract_loc(patch_id, tile, epsg, xmin, ymin, xmax, ymax, year):
    from loguru import logger
    from satio_pc.reader import S2TileReader
    from satio_pc.preprocessing.clouds import preprocess_scl
    from satio_pc.sentinel2 import BANDS_RESOLUTION
    from satio_pc.geotiff import slash_tile
    from satio_pc.utils.azure import AzureBlobReader
    
    azure = AzureBlobReader(connect_str,
                            container_name)
    
    bounds = xmin, ymin, xmax, ymax
    
    fn = f'evotrain_v2_{year}_{patch_id}.tif'
    dst_fn = f"evotrain/v2/{year}/{slash_tile(tile)}/{fn}"
    if azure.check_file_exists(dst_fn):
        logger.warning(f"Target {dst_fn} exists, skipping...")
        return True
    
    start_date = f'{year}-01-01'
    end_date = f'{year + 1}-01-01'

    max_cloud_cover = settings['max_cloud_cover']
    
    reader = S2TileReader(tile,
                          start_date,
                          end_date,
                          max_cloud_cover)
    
    logger.info("Loading and preparing SCL mask")
    scl = reader.read(bounds, epsg, ['SCL'], max_workers=max_workers)

    scl_mask = preprocess_scl(scl,
                              **settings['mask'])

    scl_valid_th = settings['scl_valid_th']  # at least 10% of valid pixels
    valid_flag = scl_mask.mask.mean(axis=(2, 3)) > scl_valid_th
    valid_flag = valid_flag.values[:, 0]
    
    scl_mask.mask = scl_mask.mask.sel(time=valid_flag)

    logger.warning(f"Keeping {valid_flag.sum()} / {valid_flag.size} products. "
                   "Discarded from the SCL filter.")
    reader._items = [i for i, b in zip(reader._items, valid_flag) if b]

    scl20 = scl_mask.mask
    scl10 = scl20.ewc.rescale()

    bands = settings['bands']

    bands_10m = [b for b in bands if BANDS_RESOLUTION[b] == 10]
    bands_20m = [b for b in bands if BANDS_RESOLUTION[b] == 20]
    
    logger.info("Loading and preparing 10m bands")
    # b10 = load_bands(bands_10m, scl10)
    b10 = reader.read(bounds, epsg, bands_10m, max_workers)
    b10 = b10.ewc.harmonize()
    b10 = b10.ewc.mask(scl10)
    b10 = b10.ewc.composite(**settings['composite'])
    b10 = b10.ewc.interpolate()
    b10 = b10 / 10000.
    
    logger.info("Loading and preparing 20m bands")
    # b20 = load_bands(bands_20m, scl20)
    b20 = reader.read(bounds, epsg, bands_20m, max_workers)
    b20 = b20.ewc.harmonize()
    b20 = b20.ewc.mask(scl20)
    b20 = b20.ewc.composite(**settings['composite'])
    b20 = b20.ewc.interpolate()
    b20 = b20 / 10000.
    
    b20 = b20.ewc.rescale()

    s2 = xr.concat([b10, b20], dim='band')

    logger.info("Computing indices")
    s2_vi = s2.ewc.indices(settings['indices'])

    logger.info("Computing percentiles")
    q = settings['percentiles']
    ps = [s.ewc.percentile(q, name_prefix='s2') for s in (s2, s2_vi)]

    # fix time to same timestamp (only 1) to avoid concat issues
    # (different compositing settings for s2 and s1)
    for p in ps:
        p['time'] = ps[0].time

    # scl aux 10m
    scl10_aux = scl_mask.aux.ewc.rescale(scale=2, order=1)
    scl10_aux['time'] = ps[0].time

    final = xr.concat(ps + [scl10_aux], dim='band')
    final.name = f's2-{patch_id}'
    
    logger.info("Saving features")
    final.isel(time=0).ewc.save_features(fn, bounds, epsg)
    
    logger.info(f"Uploading features to {dst_fn}")
    azure.upload_file(fn,
                      dst_fn,
                      overwrite=True)
    
    logger.info(f"Cleaning...")
    os.remove(fn)
    
    logger.success("Done")
    return True


def extract_loc(tup):
    patch_id, tile, epsg, xmin, ymin, xmax, ymax, year = tup
    xmin, ymin, xmax, ymax = list(map(float, (xmin, ymin, xmax, ymax)))
    epsg = int(epsg)
    year = int(year)
    
    try:
        from loguru import logger
        return _extract_loc(patch_id, tile, epsg, xmin, ymin, xmax, ymax, year)
    except Exception as e:
        logger.exception(e)
        return False

In [8]:
import pandas as pd
import xarray as xr

locs = pd.read_csv(locs_fn)

max_workers = 10

cols = ['patch_id', 'tile', 'epsg', 'xmin', 'ymin', 'xmax', 'ymax']
locs = locs[cols]

In [9]:
from satio_pc.utils.azure import AzureBlobReader
azure = AzureBlobReader(connect_str,
                        container_name)

# keys = azure.list_files()
def list_files(prefix=None):
    files = set()
    for blob in azure.container_client.list_blobs(name_starts_with=prefix):
        if blob.name[-1] != "/":
            files.add(blob.name)
    return files

In [10]:
year = 2022

keys = list_files(f'evotrain/v2/{year}/')

done_patch_ids = list(map(lambda key: "_".join(key.split('.')[0].split('_')[-3:]), keys))

locs = locs[~locs.patch_id.isin(done_patch_ids)]
locs.shape

(64435, 7)

### Cluster processing

In [11]:
# 'patch_id', 'tile', 'epsg', 'xmin', 'ymin', 'xmax', 'ymax'
args = [(loc.patch_id,
         loc.tile,
         loc.epsg,
         loc.xmin,
         loc.ymin,
         loc.xmax,
         loc.ymax,
         year)
        for loc in locs.itertuples()
        ]
len(args)

64435

In [12]:
args = np.array(args)

In [16]:
import warnings
warnings.filterwarnings("ignore")

In [13]:
final = extract_loc(args[0])

[32m2023-07-21 20:58:24.677[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m30[0m - [1mLoading and preparing SCL mask[0m
[32m2023-07-21 20:58:29.424[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m54[0m - [1mLoading and preparing 10m bands[0m
[32m2023-07-21 20:58:34.195[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m63[0m - [1mLoading and preparing 20m bands[0m
[32m2023-07-21 20:58:35.492[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m76[0m - [1mComputing indices[0m
[32m2023-07-21 20:58:35.504[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m79[0m - [1mComputing percentiles[0m
[32m2023-07-21 20:58:35.718[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m95[0m - [1mSaving features[0m
[32m2023-07-21 20:58:35.729[0m | [1mINFO    [0m | [36msatio_pc.geotiff[0m:[36msave_features_geotiff[0m:[36m207[0m - [1mSaving evotrain_v2_2021

In [15]:
import dask

extract_delayed = dask.delayed(extract_loc)

lazy_results = [extract_delayed(ag)
                for ag in args]

In [16]:
results = dask.compute(*lazy_results)


KeyboardInterrupt



# finish download missing

In [10]:
locs_all = pd.read_csv(locs_fn)

args = []

for year in range(2018, 2023):
    
    print(year)
    keys = list_files(f'evotrain/v2/{year}/')

    done_patch_ids = list(map(lambda key: "_".join(key.split('.')[0].split('_')[-3:]), keys))

    locs = locs_all[~locs_all.patch_id.isin(done_patch_ids)]
    
    args += [(loc.patch_id,
         loc.tile,
         loc.epsg,
         loc.xmin,
         loc.ymin,
         loc.xmax,
         loc.ymax,
         year)
        for loc in locs.itertuples()
        ]
    print(year, locs.shape)
    
args = np.array(args)

args.size

2018
2018 (199, 28)
2019
2019 (75, 28)
2020
2020 (59, 28)
2021
2021 (260, 28)
2022
2022 (741, 28)


10672

In [11]:
np.save('args_final.npy', args)

In [14]:
import dask

extract_delayed = dask.delayed(extract_loc)

lazy_results = [extract_delayed(ag)
                for ag in args]

results = dask.compute(*lazy_results)

# bags

In [15]:
import dask
import dask.bag as db

npartitions = 3000

b = db.from_sequence(args, npartitions=npartitions)
b = b.map(extract_loc)

In [None]:
results = b.compute()

In [None]:
for ag in args[:20]:
    extract_loc(ag)

In [18]:
ag

array(['60MUD_052_53', '60MUD', '32760', '388320.0', '9850120.0',
       '389600.0', '9851400.0', '2018'], dtype='<U32')

In [19]:
tup = args[0]

In [37]:
tup

array(['36NYF_114_40', '36NYF', '32636', '740920.0', '-9780.0',
       '742200.0', '-8780.0', '2018'], dtype='<U32')

In [42]:
xmax - xmin, ymax - ymin

(1280.0, 1000.0)

In [72]:
d = locs_all.apply(lambda r: ((r.xmax - r.xmin) + (r.ymax - r.ymin)), axis=1)

In [73]:
d.unique()

array([2560., 2280.])

In [74]:
locs_inv = locs_all[d != 2560]

In [75]:
locs_inv.shape

(17, 28)

In [81]:
for ag in args:
    
    if ag[0] in locs_inv.patch_id.values:
        print(ag[0])
        continue
    else:
        break
ag

36NYF_114_40
36NZF_113_43
35NRA_116_42
35NQA_120_31
32NRF_112_45
52NCF_112_41
49NEA_112_47


array(['01KAB_095_00', '01KAB', '32701', '171640.0', '8116800.0',
       '172920.0', '8118080.0', '2018'], dtype='<U32')

In [82]:
ag

array(['01KAB_095_00', '01KAB', '32701', '171640.0', '8116800.0',
       '172920.0', '8118080.0', '2018'], dtype='<U32')

In [83]:
tup = ag
patch_id, tile, epsg, xmin, ymin, xmax, ymax, year = tup
xmin, ymin, xmax, ymax = list(map(float, (xmin, ymin, xmax, ymax)))
epsg = int(epsg)
year = int(year)

In [84]:
bounds = xmin, ymin, xmax, ymax

In [85]:
xmax - xmin, ymax - ymin

(1280.0, 1280.0)

In [86]:
from satio_pc.reader import S2TileReader

In [87]:
start_date, end_date = '2018-01-01', '2019-01-01'

In [88]:
reader = S2TileReader(tile,
                      start_date,
                      end_date,
                      90)

# logger.info("Loading and preparing SCL mask")
# scl = reader.read(bounds, epsg, ['SCL'], max_workers=max_workers)

In [89]:
stac = reader.items

In [90]:
stac.items[0].assets['SCL'].href

IndexError: list index out of range

In [64]:
bounds

(270380.0, 5290240.0, 271660.0, 5291240.0)

In [65]:
ag

array(['19GBP_116_47', '19GBP', '32719', '270380.0', '5290240.0',
       '271660.0', '5291240.0', '2022'], dtype='<U32')

In [67]:
bounds

(270380.0, 5290240.0, 271660.0, 5291240.0)

In [66]:
reader.read(bounds, epsg, ['SCL'])