In [8]:
import pickle as pkl
import os
import geopandas as gpd
import numpy as np
import pystac_client
import planetary_computer
import rioxarray
import xarray as xr
import rasterio
from rasterio.enums import Resampling
import pickle as pkl
from datetime import timedelta
from rioxarray.exceptions import NoDataInBounds

with open('../data/failed_downloads.pkl', 'rb') as f:
    failed_downloads = pkl.load(f)

catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/",
    modifier=planetary_computer.sign_inplace,
)
query = {"eo:cloud_cover": {"lt": 1}}

collections=["sentinel-2-l2a"]

# Dates to query
start = "2016-01-01"
end = "2024-08-31"

bands = ["B02", "B03", "B04", "B08"]
bands_map = {"B02": "blue", "B03": "green", "B04": "red", "B08": "nir"}
gsd = 10
epsg = 32617

In [5]:
# Function to query STAC items for a tile
def query_stac_tile(tile_geometry, catalog, start, end, 
                    query, collections=["sentinel-2-l2a"], 
                    limit=1000):
    # Get the bounds of the tile in WGS84
    tile_wgs84 = gpd.GeoSeries([tile_geometry], crs="EPSG:32617").to_crs("EPSG:4326").iloc[0]
    minx, miny, maxx, maxy = tile_wgs84.bounds
    bbox = [minx, miny, maxx, maxy]

    # Perform the search
    search = catalog.search(
        collections=collections,
        bbox=bbox,
        datetime=f"{start}/{end}",
        limit=limit,
        query=query
    )
    # Get the items from the search results
    items = list(search.item_collection())
    return items

def get_subregion(dataset, bounds):
    min_x, min_y, max_x, max_y = bounds
    # Subset the dataset using xarray's sel function
    subregion = dataset.sel(
        x=slice(min_x, max_x),  # X-coordinate bounds
        y=slice(max_y, min_y)   # Y-coordinate bounds (flip due to coordinate system)
    )
    return subregion

# Function to get images from item
def get_images_from_item(item, bands, output_file_path, 
                         chunk_size=2048, dtype=np.float32, 
                         crs="EPSG:32617", bounds=None):
    # Check if the file already exists
    if not os.path.exists(output_file_path):

        band_datasets = []
        # Loop through bands and collect data
        for band in bands:
            # Sign the asset URLs
            asset_href = planetary_computer.sign(item.assets[band].href)
            # Open the image using rioxarray
            with rasterio.Env():
                ds = rioxarray.open_rasterio(
                    asset_href,
                    chunks={"band": -1, "x": chunk_size, "y": chunk_size},
                    lock=False
                ).astype(dtype)
                if ds.rio.crs != crs:
                    ds = ds.rio.reproject(
                        crs, 
                        resampling=Resampling.nearest,
                        num_threads=2
                    )
                if bounds is not None:
                    ds = get_subregion(ds, bounds)
                band_datasets.append(ds)

        # Stack bands into a single dataset
        stacked_ds = xr.concat(band_datasets, dim='band')

        # Store time as an attribute
        naive_datetime = item.datetime.replace(tzinfo=None)
        time_value = np.datetime64(naive_datetime, 'ns')
        stacked_ds.attrs['time'] = str(time_value)
        
        # Save the stacked dataset to a single GeoTIFF file
        # stacked_ds.rio.to_raster(output_file_path)
        # Write the data using Dask and rioxarray, with windowed=True and tiled=True
        with rasterio.Env(GDAL_CACHEMAX=512):  # Set cache size to 512 MB
            stacked_ds.rio.to_raster(
                output_file_path,
                tiled=True,
                windowed=True,
                blockxsize=256,
                blockysize=256,
                compress="deflate",
                num_threads=2,
                bigtiff='yes'
            )
    return output_file_path

# Function to clean up bounds for filename
def clean_bounds(bounds):
    minx, miny, maxx, maxy = bounds
    # Round to 3 decimal places and remove any special characters
    minx_str = f"{minx:.3f}".replace('.', '_')
    miny_str = f"{miny:.3f}".replace('.', '_')
    maxx_str = f"{maxx:.3f}".replace('.', '_')
    maxy_str = f"{maxy:.3f}".replace('.', '_')
    # Combine into a single string
    bounds_str = f"{minx_str}_{miny_str}_{maxx_str}_{maxy_str}"
    return bounds_str

In [9]:
def requery_failed_downloads(failed_downloads, catalog, query, 
                          collections=["sentinel-2-l2a"], limit=1000, 
                          save_dir="../data/missing_tiles", buffer_days=60):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    no_data_in_bounds = []
    for idx, failed_download in enumerate(failed_downloads):
        tile_polygon = failed_download['tile_geometry']  
        query_date = failed_download['query_date']
        buffer_days = 0  # Start without buffer
        print(f"Re-querying tile for date: {query_date} at coordinates: {failed_download['bounds']}")
        
        while buffer_days <= buffer_days: 
            start_date = query_date - timedelta(days=buffer_days)
            end_date = query_date + timedelta(days=buffer_days)
            
            # Query the STAC catalog for the missing tile
            items = query_stac_tile(tile_polygon, catalog, start_date.isoformat(), end_date.isoformat(), 
                                    query, collections=collections, limit=limit)
            
            if items:
                for item in items:
                    saved = False
                        # Define the output file path before querying
                    tile_bounds = tile_polygon.bounds
                    bounds_str = clean_bounds(tile_bounds)
                    band_sfx = "_".join(bands)
                    output_file = os.path.join(save_dir, f"tile_{idx}_{bounds_str}_{buffer_days}_{query_date.isoformat()}_{band_sfx}.tif")
                    
                    # Save the queried image
                    try:
                        get_images_from_item(item, bands, output_file, bounds=tile_polygon.bounds)
                        print(f"Saved re-queried tile to {output_file}")
                        saved = True
                    except NoDataInBounds as e:
                        print(e)
                if not saved:
                    no_data_in_bounds.append({
                            'idx': idx,
                            'tile_polygon': tile_polygon,
                            'query_date': query_date 
                        })
                    print(f"Appended {tile_polygon.bounds} {query_date} to `no_data_in_bounds`.")
                
                break  # Stop if data is found and saved
            
            buffer_days += 2  # Increase the buffer by 2 days if no data found

        if buffer_days > buffer_days:
            print(f"Failed to retrieve data for tile {failed_download['bounds']} on date {query_date} within buffer range.")
    return no_data_in_bounds
            

requery_failed_downloads(failed_downloads, catalog, query=query, save_dir="../data/missing_tiles")

Re-querying tile for date: 2016-04-14 at coordinates: (709800.0, 3930600.0, 710400.0, 3931200.0)
Saved re-queried tile to ../data/missing_tiles\tile_0_709800_000_3930600_000_710400_000_3931200_000_0_2016-04-14_B02_B03_B04_B08.tif
Saved re-queried tile to ../data/missing_tiles\tile_0_709800_000_3930600_000_710400_000_3931200_000_0_2016-04-14_B02_B03_B04_B08.tif
Re-querying tile for date: 2016-06-10 at coordinates: (709800.0, 3930600.0, 710400.0, 3931200.0)
Saved re-queried tile to ../data/missing_tiles\tile_0_709800_000_3930600_000_710400_000_3931200_000_0_2016-06-10_B02_B03_B04_B08.tif
Saved re-queried tile to ../data/missing_tiles\tile_0_709800_000_3930600_000_710400_000_3931200_000_0_2016-06-10_B02_B03_B04_B08.tif
Re-querying tile for date: 2016-11-27 at coordinates: (709800.0, 3930600.0, 710400.0, 3931200.0)
Saved re-queried tile to ../data/missing_tiles\tile_0_709800_000_3930600_000_710400_000_3931200_000_0_2016-11-27_B02_B03_B04_B08.tif
Saved re-queried tile to ../data/missing_til

KeyboardInterrupt: 