In [1]:
from datetime import datetime
import os
import pprint
import random

import ee
from google.cloud import storage
import numpy as np
import pandas as pd

# Google Cloud Storage client for the 'openet' project; used below to stage the
# points CSV in a bucket before Earth Engine table ingestion.
STORAGE_CLIENT = storage.Client(
    project='openet',
)


In [2]:
# Initialize Earth Engine. The notebook issues many getInfo() requests, so it
# targets the high-volume endpoint.
ee.Initialize(
    opt_url='https://earthengine-highvolume.googleapis.com',
    project='ee-cmorton',
)

In [3]:
# Minimum number of months of data required in the target year.
min_month_count = 3

# TODO: Support separate minimums for the full year and the growing season, e.g.
#   min_month_count = 6
#   min_gs_month_count = 3

# Years used for the statistics (2017 through 2024, inclusive):
#   - 2016 is excluded because there is no full prior year to interpolate from.
#   - 2024 is included even though the following year (2025) is not complete.
stats_years = [year for year in range(2017, 2024 + 1)]


In [4]:
# CGM - We could pull separate sets of points for each NLCD year but this doesn't seem that useful
# NLCD years to sample test points for (one set of per-zone points CSVs per year)
nlcd_years = [2024]
# nlcd_years = [2024, 2023, 2022, 2021, 2020, 2019, 2018, 2017, 2016]

# Earth Engine sources: monthly ensemble ET, monthly reference ET (ETo), and the
#   Annual NLCD land cover collection (first band renamed to 'landcover', which is
#   used as the class band for stratified sampling)
ensemble_coll_id = 'projects/openet/assets/ensemble/conus/gridmet/monthly/v2_1'
reference_et_coll_id = 'projects/openet/assets/reference_et/conus/gridmet/monthly/v1'
nlcd_coll = (
    ee.ImageCollection('projects/sat-io/open-datasets/USGS/ANNUAL_NLCD/LANDCOVER')
    .select([0], ["landcover"])
)
# MGRS tiles intersecting a CONUS bounding box.
# NOTE(review): ee.Image.geometry() returns the image footprint, which a pixel-value
#   operation like .eq(0) does not change. The second filterBounds therefore most
#   likely keeps every tile overlapping the water mask image's footprint, not only
#   tiles that contain mask == 0 pixels. Confirm this. The GRIDMET data_mask
#   centroid filter is applied to mgrs_coll afterwards.
mgrs_coll = (
    ee.FeatureCollection('projects/openet/assets/mgrs/tiles/modified')
    .filterBounds(ee.Geometry.BBox(-125, 25.5, -66.5, 48.5))
    .filterBounds(
        ee.Image("projects/openet/assets/features/water_mask_buffer")
        .eq(0).rename(['mask']).geometry()
    )
)
# MGRS grid zones to process; point and data CSVs are written per zone
mgrs_zones = [
    '10S', '10T', '10U', '11S', '11T', '11U', '12S', '12T', '12U', 
    '13R', '13S', '13T', '13U', '14R', '14S', '14T', '14U', '15R', '15S', '15T', '15U', 
    '16R', '16S', '16T', '17R', '17S', '17T', '18S', '18T', '19T'
    # These two zones are too small and don't have enough points
    # '12R', '16U'
]

# Identify the MGRS tiles that have a centroid that intersects the GRIDMET mask
def ftr_area(ftr):
    """Tag an MGRS tile feature with the GRIDMET data mask value at its centroid.

    Despite the name, this does not compute the tile area. It samples the
    ``data_mask`` image (unmasked, so masked pixels read as 0) at the tile
    centroid at a 4000 m scale, using the first-value reducer.

    Args:
        ftr: ee.Feature for a single MGRS tile.

    Returns:
        The same ee.Feature with a 'mask' property set.
    """
    data_mask_img = ee.Image('projects/openet/assets/mgrs/conus/gridmet/data_mask').unmask()
    centroid_mask = (
        data_mask_img
        .reduceRegion(ee.Reducer.first(), ftr.geometry().centroid(), 4000)
        .get('mask')
    )
    return ftr.set({'mask': centroid_mask})


# Keep only tiles whose centroid falls inside the GRIDMET data mask.
# ee.Filter.eq replaces the deprecated Collection.filterMetadata, and
#   FeatureCollection.map already returns a FeatureCollection.
mgrs_coll = mgrs_coll.map(ftr_area).filter(ee.Filter.eq('mask', 1))
mgrs_tiles = sorted(mgrs_coll.aggregate_array('mgrs').getInfo())

# Tiles excluded from point sampling
mgrs_tile_skip_list = ['15UUQ', '16RBT', '17RNM', '18SVE', '19TEN']

# Local folder holding the per-zone points CSV files (created when missing)
points_folder = 'points'
os.makedirs(points_folder, exist_ok=True)

# Combined points CSV, and the Earth Engine table asset it is ingested into
points_csv = 'gap_fill_test_points.csv'
points_coll_id = 'projects/ee-cmorton/assets/gap_fill_test_points'

# Cloud Storage bucket used to stage the CSV for ingestion
bucket_name = 'openet_temp'
bucket = STORAGE_CLIENT.bucket(bucket_name)

# Calendar months 1..12
months = [month for month in range(1, 13)]

print_n = 500


In [5]:
# Build a list of test points stratified by NLCD class, written to one CSV per MGRS grid zone.
# max_points is the per-class point count for a full-size tile (1e10 m^2). Smaller (edge)
#   tiles are scaled down by area, so they can drop to 1 point or be skipped at 0.
max_points = 2
overwrite_flag = False

# Explicit CSV columns so a zone with no points still gets a header row,
#   which keeps the file readable by pd.read_csv in the combine step.
point_columns = ['latitude', 'longitude', 'mgrs_tile', 'mgrs_zone', 'nlcd', 'nlcd_year']

# Points are built separately per NLCD year, since the land cover can differ by year.
# It would probably make more sense to build a single point list
#   and then extract the NLCD values for each/all years.
for nlcd_year in nlcd_years:
    print(f'NLCD: {nlcd_year}')

    nlcd_img = nlcd_coll.filterDate(f'{nlcd_year}-01-01', f'{nlcd_year+1}-01-01').first()

    # Mask out water (11), perennial ice/snow (12), and developed (21, 22, 23) land covers.
    # NOTE(review): class 24 (developed, high intensity) is not masked here even though
    #   the other developed classes are. Confirm whether this is intended.
    nlcd_img = nlcd_img.updateMask(
        nlcd_img.eq(11).Or(nlcd_img.eq(12))
        .Or(nlcd_img.eq(21)).Or(nlcd_img.eq(22)).Or(nlcd_img.eq(23))
        .Not()
    )

    for mgrs_zone in mgrs_zones:
        points_mgrs_year_csv = os.path.join(points_folder, f'points_{mgrs_zone}_{nlcd_year}.csv')
        if os.path.isfile(points_mgrs_year_csv) and not overwrite_flag:
            continue
        print(f'{mgrs_zone}')

        point_list = []
        for mgrs_tile in mgrs_tiles:
            if mgrs_tile[:3] != mgrs_zone or mgrs_tile in mgrs_tile_skip_list:
                continue

            mgrs_geom = mgrs_coll.filter(ee.Filter.eq('mgrs', mgrs_tile)).first().geometry()

            # Per-class point count scaled by the tile area relative to 1e10 m^2,
            #   rounded and clamped to [0, max_points]
            mgrs_points = (
                mgrs_geom.area().divide(10000000000).multiply(max_points)
                .round().clamp(0, max_points).getInfo()
            )
            print(f'  {mgrs_tile} - Points: {mgrs_points}')
            if mgrs_points == 0:
                continue

            point_coll = nlcd_img.stratifiedSample(
                numPoints=mgrs_points,
                classBand='landcover',
                region=mgrs_geom,
                dropNulls=True,
                geometries=True,
                seed=nlcd_year,
            )
            try:
                point_info = point_coll.getInfo()['features']
            except Exception as e:
                # Best effort: report the error and skip this tile instead of aborting the zone
                print(f'  Error getting point collection, skipping: {e}')
                continue

            point_list.extend(
                {
                    'latitude': pnt['geometry']['coordinates'][1],
                    'longitude': pnt['geometry']['coordinates'][0],
                    'mgrs_tile': mgrs_tile,
                    'mgrs_zone': mgrs_zone,
                    'nlcd': pnt['properties']['landcover'],
                    'nlcd_year': nlcd_year,
                }
                for pnt in point_info
            )

        print(f'  Total Points: {len(point_list)}')
        pd.DataFrame(point_list, columns=point_columns).to_csv(points_mgrs_year_csv, index=False)

print('\nDone')

NLCD: 2024

Done


In [6]:
# Combine the per-zone points CSV files into a single points dataframe and CSV
overwrite_flag = False

# Read every per-zone points CSV (for each NLCD year) that exists on disk
points_df_list = []
for nlcd_year in nlcd_years:
    for mgrs_zone in mgrs_zones:
        points_zone_csv = os.path.join(points_folder, f'points_{mgrs_zone}_{nlcd_year}.csv')
        if not os.path.isfile(points_zone_csv):
            continue
        try:
            points_zone_df = pd.read_csv(points_zone_csv)
        except pd.errors.EmptyDataError:
            # A zone with no sampled points can end up as an empty, header-less file
            print(f'  {points_zone_csv} is empty, skipping')
            continue
        if not points_zone_df.empty:
            points_df_list.append(points_zone_df)

if not points_df_list:
    # pd.concat([]) raises an unhelpful "No objects to concatenate"
    raise FileNotFoundError(f'No non-empty points CSV files found in {points_folder}')
points_df = pd.concat(points_df_list, axis=0, ignore_index=True)
print(f'Points: {len(points_df.index)}')

# Derive mgrs_zone from the tile ID so it is present and consistent,
#   including in CSVs written before the column was added
points_df['mgrs_zone'] = points_df['mgrs_tile'].str.slice(0, 3)

# Unique point ID: <MGRS TILE>_nlcd<class, 2 digits>_<sequence within tile/class, 2 digits>
point_seq = points_df.groupby(['mgrs_tile', 'nlcd']).cumcount()
points_df['point_id'] = (
    points_df['mgrs_tile'].str.upper() + '_'
    + 'nlcd' + points_df['nlcd'].astype(str).str.zfill(2) + '_'
    + point_seq.astype(str).str.zfill(2)
)
assert points_df['point_id'].is_unique, 'point_id values are not unique'

# Round the lat and lon to 8 decimal places (probably should be 6)
points_df['latitude'] = points_df['latitude'].round(8)
points_df['longitude'] = points_df['longitude'].round(8)

# Write to CSV
if overwrite_flag or not os.path.isfile(points_csv):
    print('Writing points csv')
    points_df.to_csv(points_csv, index=False)


Points: 17084


In [7]:
# Upload the combined points CSV to Cloud Storage and ingest it as an Earth Engine table asset
overwrite_flag = False

# Upload to bucket (the blob is stored at the bucket root under the CSV file name)
bucket_path = f'gs://{bucket_name}/{points_csv}'
blob = bucket.blob(points_csv)
if overwrite_flag or not blob.exists():
    print('Uploading csv to bucket')
    blob.upload_from_filename(points_csv)

# Ingest into GEE
if overwrite_flag and ee.data.getInfo(points_coll_id):
    print('Removing existing collection')
    ee.data.deleteAsset(points_coll_id)

if not ee.data.getInfo(points_coll_id):
    print('Starting csv ingest')
    manifest = {'name': points_coll_id, 'sources': [{'uris': [bucket_path]}]}
    ingest_task = ee.data.startTableIngestion(None, manifest)
    # Ingestion runs as an asynchronous task, and the asset cannot be read until it
    #   finishes. Wait for the task to complete before running the extraction cell.
    print(f'  {ingest_task}')

# Note: an earlier approach built the feature collection client-side and used
#   ee.batch.Export.table.toAsset. The full collection is too large to export
#   directly to an asset, so CSV ingestion is used instead.

In [8]:
# Pull the monthly ensemble ET and reference ET (ETo) time series at all points,
#   writing one CSV per MGRS grid zone into the data folder.
# TODO: Run extractions by MGRS tile to better control crs and transform parameters?
overwrite_flag = False

# Output folder for the per-zone data CSVs (to_csv fails if it does not exist)
data_folder = 'data'
os.makedirs(data_folder, exist_ok=True)

for mgrs_zone in mgrs_zones:
    data_mgrs_csv = os.path.join(data_folder, f'data_{mgrs_zone}.csv')
    if os.path.isfile(data_mgrs_csv) and not overwrite_flag:
        continue

    mgrs_points_coll = (
        ee.FeatureCollection(points_coll_id)
        .filter(ee.Filter.eq('mgrs_zone', mgrs_zone))
    )
    print(f'{mgrs_zone} - Points: {mgrs_points_coll.size().getInfo()}')

    # TODO: Lookup all the crs and transform values for the MGRS zones ahead of time
    # Get the crs and transform for the zone from one of its ensemble images
    try:
        mgrs_info = (
            ee.ImageCollection(ensemble_coll_id)
            .filter(ee.Filter.eq('mgrs_tile', mgrs_zone))
            .first().select(['et_ensemble_mad']).getInfo()
        )
    except Exception as e:
        # No CSV is written, so the zone is retried on the next run
        print(f'  Could not read the ensemble projection, skipping zone: {e}')
        continue
    crs = mgrs_info['bands'][0]['crs']
    crs_transform = mgrs_info['bands'][0]['crs_transform']

    # Get the list of images to read from
    image_id_list = (
        ee.ImageCollection(ensemble_coll_id)
        .filterDate('2015-10-01', '2025-10-01')
        .filter(ee.Filter.eq('mgrs_tile', mgrs_zone))
        .aggregate_array('system:index').getInfo()
    )

    # Extract values
    image_rows = []
    failed_images = []
    for image_i, image_id in enumerate(image_id_list):
        if image_i % 10 == 0:
            print(image_i, image_id)

        src_img = ee.Image(f'{ensemble_coll_id}/{image_id}')
        image_date = ee.Date(src_img.get('system:time_start'))

        # Getting ETo and the ensemble in one call gave errors, so ETo is pulled
        #   separately (bilinear resampling at a 30 m scale)
        eto_img = (
            ee.ImageCollection(reference_et_coll_id)
            .filterDate(image_date.advance(-1, 'day'), image_date.advance(1, 'day')).first()
            .select(['eto'])
            .resample('bilinear')
        )
        try:
            eto_info = eto_img.reduceRegions(
                collection=mgrs_points_coll,
                reducer=ee.Reducer.first(),
                scale=30,
            ).getInfo()
        except Exception as e:
            print(f'ETo Exception: {e}')
            failed_images.append(image_id)
            continue
        # point_id -> rounded ETo (points without a value have no 'first' property)
        eto_values = {
            ftr['properties']['point_id']: round(ftr['properties']['first'])
            for ftr in eto_info['features']
            if 'first' in ftr['properties']
        }

        # Extract the ensemble ET and count in the ensemble image's own projection
        try:
            image_info = (
                src_img.select(['et_ensemble_mad', 'et_ensemble_mad_count'], ['et', 'count'])
                .reduceRegions(
                    collection=mgrs_points_coll,
                    reducer=ee.Reducer.first(),
                    crs=crs,
                    crsTransform=crs_transform,
                )
                .getInfo()
            )
        except Exception as e:
            print(f'  ET Exception: {e}')
            print(f'  {image_id}')
            failed_images.append(image_id)
            continue

        # Image IDs are parsed as '<zone>_<YYYYMMDD>'
        image_date_str = datetime.strptime(image_id.split('_')[1], '%Y%m%d').strftime('%Y-%m-%d')

        # Build one row per point. The ETo value decides whether the point is valid
        #   and kept. NOTE(review): the truthiness test also drops points whose rounded ETo is 0.
        for pnt in image_info['features']:
            props = pnt['properties']
            eto = eto_values.get(props['point_id'])
            if not eto:
                continue
            image_rows.append({
                'date': image_date_str,
                'point_id': props['point_id'],
                'mgrs_tile': props['mgrs_tile'],
                'mgrs_zone': props['mgrs_zone'],
                'nlcd': props['nlcd'],
                'nlcd_year': props['nlcd_year'],
                # Masked ensemble pixels have no et/count property; record them as missing
                'count': props.get('count'),
                'et': props.get('et'),
                'eto': eto,
            })

    if failed_images:
        # The CSV is cached and skipped on later runs, so flag the incomplete output
        print(f'  WARNING: {len(failed_images)} image(s) failed and are missing from {data_mgrs_csv}')
    pd.DataFrame(image_rows).to_csv(data_mgrs_csv, index=False)

    del mgrs_points_coll, mgrs_info, image_id_list, image_rows

print('\nDone')



Done
