In [None]:
import rioxarray
import xarray as xr
import numpy as np

import os
import glob
import pathlib
import json
import boto3
import s3fs
import shutil
from typing import Any, Dict, Generic, Iterable, List, Optional, TypeVar, Union, cast
from datetime import datetime
from pprint import pprint

import geopandas as gpd
from shapely.geometry import box, mapping, Polygon, LinearRing
import warnings
warnings.filterwarnings("ignore")

import pystac
from pystac import Catalog, get_stac_version, Collection, Item, STACObject, CatalogType, Link
from pystac.stac_io import DefaultStacIO, StacIO
from pystac.extensions.projection import AssetProjectionExtension, ProjectionExtension

In [None]:
# Delete this block once we are in GDAL >= 3.5
#
# Mirror the credentials of the active boto3 session into ~/.aws/credentials
# (presumably so GDAL < 3.5 can pick them up from the shared credentials file).
# NOTE(review): any existing ~/.aws/credentials file is overwritten.

awsdir = pathlib.Path("~/.aws").expanduser()
awsdir.mkdir(parents=True, exist_ok=True)
awscred = awsdir / "credentials"

session = boto3.session.Session()
credentials = session.get_credentials()
if credentials is None:
    raise RuntimeError("No AWS credentials could be resolved for the current boto3 session")

# Credentials are refreshable, so accessing your access key / secret key
# separately can lead to a race condition. Use this to get an actual matched
# set.
credentials = credentials.get_frozen_credentials()

credential_lines = [
    "[default]",
    f"aws_access_key_id={credentials.access_key}",
    f"aws_secret_access_key={credentials.secret_key}",
]
# Only temporary credentials carry a session token; writing the literal
# string "None" for long-lived keys would produce an invalid profile.
if credentials.token:
    credential_lines.append(f"aws_session_token={credentials.token}")
outstring = "\n" + "\n".join(credential_lines) + "\n"

# The file contains secrets: restrict it to the current user before writing.
awscred.touch(mode=0o600, exist_ok=True)
awscred.chmod(0o600)
with awscred.open(mode='w') as f:
    f.write(outstring)

In [None]:
# S3 bucket that is searched for AVIRIS-NG Zarr stores and that receives the
# flight outline uploads.
Bucket = "dh-shift-curated"

In [None]:
# Locations of the STAC catalog and collection JSON files.
# Both are absolute paths so they resolve the same way regardless of the
# notebook's working directory.
catalog_url = "/efs/SBG-SHIFT-STAC/catalog.json"
collection_url = "/efs/SBG-SHIFT-STAC/collection.json"

In [None]:
# Load the existing STAC catalog from catalog_url and print a summary of it.
catalog = Catalog.from_file(catalog_url)
catalog.describe()

In [None]:
#As new data becomes available for AVIRIS-NG, add to catalog

# Read in data from S3 bucket
def get_zarrs(Bucket, dataset_date):
    """List the AVIRIS-NG Zarr stores available in S3 for one flight date.

    Every key under ``aviris/{dataset_date}`` is scanned page by page. Each key
    that contains ``100-100-100.zarr`` is cut right after that marker and the
    ``aviris/{dataset_date}/`` prefix is removed, leaving the store name.

    Parameters
    ----------
    Bucket : str
        Name of the S3 bucket to search.
    dataset_date : str
        Flight date as it appears in the key prefix (e.g. ``'20220420'``).

    Returns
    -------
    list of str
        Unique Zarr store names, sorted so repeated runs yield the same order.
        Empty when no key matches.
    """
    s3 = boto3.client('s3')
    substring = '100-100-100.zarr'
    date_prefix = f"aviris/{dataset_date}/"
    kwargs = {'Bucket': Bucket, 'Prefix': f'aviris/{dataset_date}'}
    zarr_names = set()
    while True:
        objects = s3.list_objects_v2(**kwargs)
        # 'Contents' is absent from the response when no key matches the prefix.
        for obj in objects.get('Contents', []):
            key = obj['Key']
            if substring in key:
                store_key = key[:key.index(substring) + len(substring)]
                zarr_names.add(store_key.replace(date_prefix, ""))

        # A continuation token is only present while more pages remain.
        token = objects.get('NextContinuationToken')
        if token is None:
            break
        kwargs['ContinuationToken'] = token

    data = sorted(zarr_names)
    print("Data retrieved")
    return data

def create_items_add_assets(Bucket, dataset_date, zarr):
    """Build the STAC item for one AVIRIS-NG flight-line Zarr store.

    The Zarr store is opened from S3 and the Easting/Northing values at its
    four corners define the flight outline. The outline is written to a local
    GeoJSON file (in the current working directory) and uploaded next to the
    Zarr store. A ``pystac.Item`` is then created with a dataset asset and a
    flight-outline asset.

    Parameters
    ----------
    Bucket : str
        S3 bucket that holds the Zarr store and receives the GeoJSON outline.
    dataset_date : str
        Flight date in ``YYYYMMDD`` form; also the key prefix under ``aviris/``.
    zarr : str
        Zarr store name containing ``_100-100-100.zarr``.

    Returns
    -------
    pystac.Item
        The new item. It is also stored in the module-level ``zarrs_to_items``
        dict under the item name; that dict must exist before this is called.

    Raises
    ------
    ValueError
        If ``zarr`` does not contain ``_100-100-100.zarr`` or ``dataset_date``
        is not in ``YYYYMMDD`` form.
    """
    region = 'us-west-2'
    epsg = 32610

    # Derive names first so malformed input fails before any S3 access.
    substring = '_100-100-100.zarr'
    item_name = zarr[:zarr.index(substring)]
    # create datetime object from date string
    dt = datetime.strptime(dataset_date, "%Y%m%d")

    zarr_key = f"aviris/{dataset_date}/{zarr}"
    outline_file = f"{item_name}_flight_outline.geojson"
    outline_key = f"aviris/{dataset_date}/{outline_file}"
    base_href = f"https://{Bucket}.s3.{region}.amazonaws.com"

    # Open flight path zarr
    fs = s3fs.S3FileSystem(anon=False, client_kwargs=dict(region_name=region))
    store = s3fs.S3Map(root=f"s3://{Bucket}/{zarr_key}", s3=fs, check=False)
    ds = xr.open_zarr(store=store, decode_coords="all", consolidated=True)
    print("Zarr read done!")

    # Set dimensions and coordinates
    ds = ds.rio.set_spatial_dims(x_dim='x', y_dim='y')
    ds = ds.set_coords(('Easting', 'Northing'))

    # Calculate extent, bbox, and footprint using Easting & Northing
    ul = ds.isel(x=0, y=0)    # Upper left corner
    ur = ds.isel(x=-1, y=0)   # Upper right corner
    ll = ds.isel(x=0, y=-1)   # Lower left corner
    lr = ds.isel(x=-1, y=-1)  # Lower right corner

    ul2 = (np.min(ul.Easting.values), np.max(ul.Northing.values))
    ur2 = (np.max(ur.Easting.values), np.max(ur.Northing.values))
    ll2 = (np.min(ll.Easting.values), np.min(ll.Northing.values))
    lr2 = (np.max(lr.Easting.values), np.min(lr.Northing.values))

    # NOTE(review): these coordinates are taken directly from Easting/Northing
    # and are used unchanged for the item geometry/bbox and the GeoJSON file.
    # STAC and GeoJSON normally expect WGS84 lon/lat there - confirm whether a
    # reprojection from the projected CRS is needed.
    extent = Polygon([ul2, ll2, lr2, ur2])
    footprint = mapping(extent)
    bbox = list(extent.bounds)

    # save flight outline as GeoJSON and upload to s3
    gpd.GeoDataFrame({"geometry": extent}, index=[0]).to_file(outline_file, driver='GeoJSON')

    # Upload the file that was just written; a put_object call without a Body
    # would only create an empty object.
    s3_client = boto3.client('s3')
    s3_client.upload_file(
        outline_file,
        Bucket,
        outline_key,
        ExtraArgs={'ContentType': 'application/geo+json'},
    )

    item_url = f"/efs/SBG-SHIFT-STAC/AVIRIS-NG/{item_name}.json"

    # item should be flight line
    item = pystac.Item(
        id=item_name,
        geometry=footprint,
        bbox=bbox,
        datetime=dt,
        href=item_url,
        stac_extensions=['https://stac-extensions.github.io/projection/v1.0.0/schema.json'],
        properties={},
        collection='AVIRIS-NG',
    )
    # add instrument metadata
    item.common_metadata.instruments = ['AVIRIS-NG']

    # add projection extension to item
    proj_ext = ProjectionExtension.ext(item, add_if_missing=True)
    proj_ext.epsg = epsg
    print("Item created!")

    # Add dataset asset
    print("Adding assets")
    item.add_asset(
        key="dataset",
        asset=pystac.Asset(href=f"{base_href}/{zarr_key}"),
    )

    # extend the asset with projection extension
    asset_ext = AssetProjectionExtension.ext(item.assets["dataset"])
    asset_ext.epsg = epsg
    asset_ext.bbox = bbox
    asset_ext.geometry = footprint

    # Add flight outline GeoJSON asset
    item.add_asset(
        key="Flight Outline",
        asset=pystac.Asset(
            href=f"{base_href}/{outline_key}",
            media_type=pystac.MediaType.GEOJSON,
        ),
    )
    print("Assets added!")

    # Register the item in the notebook-level dict and also hand it back so
    # callers do not have to rely on the global.
    zarrs_to_items[item_name] = item
    return item

In [None]:
# Flight dates to ingest; additional dates after 20220412 to be added later
dates = [
    '20220420',
    '20220429',
    '20220503',
    # plus any other dates
]

# Filled by create_items_add_assets, keyed by item name
zarrs_to_items = {}

for dataset_date in dates:
    # All Zarr stores found for this flight date
    data = get_zarrs(Bucket, dataset_date)

    for zarr in data:
        # The item name is the Zarr name up to the chunking suffix
        substring = '_100-100-100.zarr'
        item_name = zarr[:zarr.index(substring)]
        print(f"Starting {item_name}")

        # create STAC item and add assets, including Zarr dataset
        create_items_add_assets(Bucket, dataset_date, zarr)

In [None]:
# add items to collection
# No earlier cell defines `collection`, so load it from collection_url here;
# otherwise this cell raises NameError on a fresh kernel.
collection = Collection.from_file(collection_url)
collection.add_items(zarrs_to_items.values())
collection.describe()

In [None]:
# Add STAC Collection to STAC catalog
# NOTE(review): if the collection is already linked as a child of this
# catalog, this may add a second child link - confirm against the catalog file.
catalog.add_child(collection)
catalog.describe()

In [None]:
# To add labels once GeoTIFFs are created, see:
# https://pystac.readthedocs.io/en/stable/tutorials/pystac-spacenet-tutorial.html

In [None]:
# Save new modified catalog to efs
# normalize_hrefs expects the catalog's root *directory*; the layout strategy
# appends "catalog.json" itself, so passing the file path would write the
# catalog under a nested "catalog.json/" directory.
catalog.normalize_hrefs(os.path.dirname(catalog_url))
catalog.save(catalog_type=CatalogType.SELF_CONTAINED)

In [None]:
# open newly-modified catalog
# Re-read the catalog file from catalog_url as a check that it can be loaded.
copycat = Catalog.from_file(catalog_url)

print('Done!')