In [None]:
import json
import multiprocessing
import os
import pprint
import time

import ee
import pandas as pd
from google.cloud import storage
#import retry

from IPython.display import Image, display
#import ipyplot

# gsutil -m rm "gs://openet_temp/skip_scene_stats/2024/*.csv"
# gsutil -m cp "gs://openet_temp/skip_scene_stats/2024/*.csv" ./stats/2024/

# Destination of the exported stats CSVs: gs://<BUCKET_NAME>/<BUCKET_FOLDER>/<year>/
BUCKET_NAME, BUCKET_FOLDER = 'openet_temp', 'skip_scene_stats'
# Cloud Storage client for the 'openet' project; the export loop uses it to
# list and delete CSVs that already exist in the bucket
storage_client = storage.Client(project='openet')


In [None]:
# Initialize Earth Engine with a service account key file.
# The key path defaults to the DRI key and can be overridden with the
# EE_KEY_FILE environment variable (keeps machine-specific paths out of the code).
# The trailing commas let the optional keyword arguments below be uncommented
# without producing a syntax error.
ee.Initialize(
    ee.ServiceAccountCredentials(
        '_', key_file=os.environ.get('EE_KEY_FILE', '../../keys/openet-dri-gee.json'),
    ),
    # Alternative key (use instead of the credentials above):
    # ee.ServiceAccountCredentials('_', key_file='../../keys/openet-gee.json'),
    # project='ee-cmorton',
    # opt_url='https://earthengine-highvolume.googleapis.com',
)


In [None]:
# Land mask (1 = land, 0 = water):
#   - the OpenET water mask is True for water, so it is inverted with Not()
#   - anywhere the NLCD/NALCMS land cover is water (class 18) is also set to 0;
#     masked NALCMS pixels are filled with 18 first, so they are treated as water too
land_mask = (
    ee.Image('projects/openet/assets/features/water_mask').Not()
    .where(ee.Image("USGS/NLCD_RELEASES/2020_REL/NALCMS").unmask(18).eq(18), 0)
)
# Alternative: land_mask = land_mask.And(ee.Image("USGS/NLCD_RELEASES/2020_REL/NALCMS").unmask(18).neq(18))

# Alternative: land_mask = ee.Image('projects/openet/assets/meteorology/conus404/ancillary/land_mask')

# Local folder for stats CSVs copied down from the bucket (see the gsutil note at the top).
# exist_ok avoids the check-then-create race of an isdir() test followed by makedirs().
stats_ws = os.path.join(os.getcwd(), 'stats')
os.makedirs(stats_ws, exist_ok=True)

# # Use the OpenET ssebop collection for building the WRS2 list for now
# wrs2_list = sorted(
#     # ee.ImageCollection('projects/openet/assets/ssebop/conus/gridmet/landsat/c02')
#     # ee.ImageCollection('projects/openet/assets/intercomparison/ssebop/landsat/c02/v0p2p6')
#     ee.ImageCollection('projects/usgs-gee-nhm-ssebop/assets/ssebop/landsat/c02')
#     .filterDate('2020-01-01', '2024-01-01')
#     .aggregate_histogram('wrs2_tile').keys().getInfo(),
#     reverse=True
# )
# wrs2_list = wrs2_list + ['p018r028']
# # print(len(wrs2_list))

# # Use the OpenET ssebop collection for building the WRS2 list for now
# wrs2_subset_list = sorted(
#     # ee.ImageCollection('projects/openet/assets/intercomparison/ssebop/landsat/c02/v0p2p6')
#     ee.ImageCollection('projects/usgs-gee-nhm-ssebop/assets/ssebop/landsat/c02')
#     .filterDate('2020-01-01', '2024-01-01')
#     .aggregate_histogram('wrs2_tile').keys().getInfo(),
#     reverse=True
# )
# wrs2_subset_list = wrs2_subset_list + ['p018r028']
# wrs2_list = [wrs2 for wrs2 in wrs2_subset_list if wrs2 not in wrs2_list]
# print(len(wrs2_list))

# Coastal, island, and Caribbean WRS2 tiles that are excluded from processing.
# Each tile is listed once; a tile that could belong to two regions is noted in a comment.
wrs2_skip_list = [
    'p050r026',  # Vancouver Island
    'p048r028',  # OR/WA Coast
    'p042r037',  # San Nicolas Island, California
    'p040r040', 'p039r040',  # Isla Guadalupe
    'p038r043', 'p036r043',  # Baja coast
    'p019r040', 'p018r040',  # West Florida coast
    'p016r043', 'p015r043',  # South Florida coast
    'p014r041', 'p014r042', 'p014r043',  # East Florida coast (p014r041/p014r042 also cover the Bahamas)
    'p013r034', 'p013r035', 'p013r036',  # North Carolina Outer Banks
    'p011r032',  # Rhode Island coast
    'p010r030',  # Maine
    # Caribbean tiles
    'p013r041', 'p013r042',  # Bahamas
    'p012r042', 'p012r043', 'p011r042', 'p011r043',  # Bahamas
    'p013r043',  # Bahamas (main island)
    'p006r037', 'p006r038',  # Bermuda
    'p017r044', 'p016r044', 'p015r044', 'p014r044',  # Cuba
    'p013r044', 'p012r044', 'p011r044', 'p010r044',  # Cuba/Bahamas
]
# Guard against a tile being added twice under different region comments
assert len(wrs2_skip_list) == len(set(wrs2_skip_list)), 'duplicate tile in wrs2_skip_list'

# All custom WRS2 tiles intersecting the lon/lat box (-127..-63, 24..52),
# minus the skip list, sorted in reverse order
wrs2_list = sorted(
    ee.FeatureCollection('projects/openet/assets/features/wrs2/custom')
    # .filterBounds(ee.Geometry.BBox(-125.5, 25, -65.5, 52))
    .filterBounds(ee.Geometry.BBox(-127, 24, -63, 52))
    .filter(ee.Filter.inList('wrs2_tile', wrs2_skip_list).Not())
    .aggregate_histogram('wrs2_tile').keys().getInfo(),
    reverse=True
)
print(len(wrs2_list))


In [None]:
# def covering_grid(image, scale):
#     # TODO: Get the offsets from the image transform some how?
#     x_offset = 15
#     y_offset = 15
#     # Get the upper left coordinates from the transform
#     image_geom = image.select([0]).geometry()
#     image_crs = image.select([0]).projection().crs()
#     xy = ee.Array(image_geom.bounds(1, image_crs).coordinates().get(0)).transpose().toList()
#     # Compute the max and min X and Y values for defining the extent
#     xmin = ee.Number(ee.List(xy.get(0)).reduce(ee.Reducer.min()))
#     ymin = ee.Number(ee.List(xy.get(1)).reduce(ee.Reducer.min()))
#     xmax = ee.Number(ee.List(xy.get(0)).reduce(ee.Reducer.max()))
#     ymax = ee.Number(ee.List(xy.get(1)).reduce(ee.Reducer.max()))
#     # Adjust the extent parameters to be buffered and snapped to the Landsat grid (15, 15) and gridsize
#     xmin = xmin.subtract(x_offset).divide(scale).floor().multiply(scale).add(x_offset)
#     ymin = ymin.subtract(y_offset).divide(scale).floor().multiply(scale).add(y_offset)
#     xmax = xmax.subtract(x_offset).divide(scale).ceil().multiply(scale).add(x_offset)
#     ymax = ymax.subtract(y_offset).divide(scale).ceil().multiply(scale).add(y_offset)
#     # Compute the number of columns and rows needed to generate the grid
#     # Subtract 1 since these are the lower left indices
#     num_cols = xmax.subtract(xmin).abs().divide(scale).subtract(1)
#     num_rows = ymax.subtract(ymin).abs().divide(scale).subtract(1)
#     # Build the list of column and row lower left coordinates
#     cols = ee.List.sequence(xmin, num_cols.multiply(scale).add(xmin), scale) 
#     rows = ee.List.sequence(ymin, num_rows.multiply(scale).add(ymin), scale)
    
#     # Build the grid feature collection
#     def create_grid_coll(c):
#         all_fts = ee.FeatureCollection([]) 
#         c_tag = ee.List(cols).indexOf(c)
#         def build_row(r):
#             cell_geom = ee.Geometry.Rectangle(
#                 ee.List([ee.Number(c), ee.Number(r), ee.Number(c).add(scale), ee.Number(r).add(scale)]),
#                 image_crs, False
#             )
#             return ee.Feature(cell_geom, {'col': c_tag, 'row': ee.List(rows).indexOf(r)})
#         row_fts = ee.FeatureCollection(rows.map(build_row))
        
#         all_fts = all_fts.merge(row_fts)
#         return all_fts
    
#     return ee.FeatureCollection(cols.map(create_grid_coll)).flatten()

# # landsat_id = 'LANDSAT/LC09/C02/T1_L2/LC09_029027_20240413'
# # pprint.pprint(covering_grid(ee.Image(landsat_id).select(['QA_PIXEL']), scale=10000).getInfo())


In [None]:
# Common band names that every Landsat SR image is renamed to before computing stats
refl_sr_bands = ['SR_RED', 'SR_GREEN', 'SR_BLUE', 'QA_PIXEL', 'QA_RADSAT']

# Source red/green/blue band numbers per sensor:
#   TM/ETM+ (LT04, LT05, LE07) -> bands 3, 2, 1
#   OLI     (LC08, LC09)       -> bands 4, 3, 2
# Level-2 SR bands carry an 'SR_' prefix and are followed by the two QA bands
# (same order as refl_sr_bands); TOA bands use the bare 'B#' names.
refl_sr_bands_dict = ee.Dictionary({
    sensor: [f'SR_B{b}' for b in ((4, 3, 2) if sensor[:2] == 'LC' else (3, 2, 1))]
    + ['QA_PIXEL', 'QA_RADSAT']
    for sensor in ('LT04', 'LT05', 'LE07', 'LC08', 'LC09')
})
refl_toa_bands_dict = ee.Dictionary({
    sensor: [f'B{b}' for b in ((4, 3, 2) if sensor[:2] == 'LC' else (3, 2, 1))]
    for sensor in ('LT04', 'LT05', 'LE07', 'LC08', 'LC09')
})
# sr_coll_dict = ee.Dictionary({
#     'LT05': ee.ImageCollection('LANDSAT/LT05/C02/T1_L2'),
#     'LE07': ee.ImageCollection('LANDSAT/LE07/C02/T1_L2'),
#     'LC08': ee.ImageCollection('LANDSAT/LC08/C02/T1_L2'),
#     'LC09': ee.ImageCollection('LANDSAT/LC09/C02/T1_L2'),
# })
# toa_coll_dict = ee.Dictionary({
#     'LT05': ee.ImageCollection('LANDSAT/LT05/C02/T1_TOA'),
#     'LE07': ee.ImageCollection('LANDSAT/LE07/C02/T1_TOA'),
#     'LC08': ee.ImageCollection('LANDSAT/LC08/C02/T1_TOA'),
#     'LC09': ee.ImageCollection('LANDSAT/LC09/C02/T1_TOA'),
# })


In [None]:
def image_stats(landsat_img):
    """Compute per-scene mask pixel counts and mean RGB reflectance statistics.

    Parameters
    ----------
    landsat_img : ee.Image
        Landsat Collection 2 Level-2 image with bands renamed to
        SR_RED, SR_GREEN, SR_BLUE, QA_PIXEL, QA_RADSAT, a 'scene_id' property
        (e.g. 'LC08_019037_20200911'), and a 'landsat_toa_img' property holding
        the matching TOA image (as attached by the join in the export loop).

    Returns
    -------
    ee.Feature
        Null-geometry feature whose properties are the scene statistics.
        Values start from `default_stats` (-1 placeholders) and are overwritten
        by the reduceRegion results.

    Notes
    -----
    Keys prefixed with UNMASKED_ are computed after removing pixels flagged by
    the QA_PIXEL based mask (cloud, cirrus, dilated cloud, shadow, snow); the
    unprefixed SR_/TOA_ keys and TOTAL_PIXELS use all pixels. Every statistic is
    restricted to the module-level `land_mask` image.
    """
    scene_id = ee.String(landsat_img.get('scene_id'))
    # First 4 characters of the scene ID are the sensor code (LT04, LT05, LE07, LC08, LC09)
    landsat_type = scene_id.slice(0, 4)

    # Note, we can't rename the TOA bands here since the simple cloud score is expecting a raw TOA image
    landsat_toa_img = ee.Image(landsat_img.get('landsat_toa_img'))

    # -1 placeholders so every output column has a value even if a reducer does not
    # return that key; the MORAN_* columns are placeholders that are not computed here
    default_stats = ee.Dictionary({
        'SCENE_ID': scene_id,
        'UNMASKED_PIXELS': -1, 'TOTAL_PIXELS': -1,
        'CLOUD_PIXELS': -1, 'CIRRUS_PIXELS': -1, 'DILATE_PIXELS': -1,
        'SHADOW_PIXELS': -1, 'SNOW_PIXELS': -1, 'WATER_PIXELS': -1,
        'SATURATED_PIXELS': -1, 'ACCA_PIXELS': -1,
        'SR_RED': -1, 'SR_GREEN': -1, 'SR_BLUE': -1,
        'UNMASKED_SR_RED': -1, 'UNMASKED_SR_GREEN': -1, 'UNMASKED_SR_BLUE': -1,
        'TOA_RED': -1, 'TOA_GREEN': -1, 'TOA_BLUE': -1,
        'UNMASKED_TOA_RED': -1, 'UNMASKED_TOA_GREEN': -1, 'UNMASKED_TOA_BLUE': -1,
        'CLOUD_COVER_LAND': landsat_img.get('CLOUD_COVER_LAND'),
        'MORAN_1K': -1, 'MORAN_2K': -1, 'MORAN_4K': -1, 'MORAN_8K': -1,
        #'SSEBOP_ETF_count': -9999, 'SSEBOP_ETF_mean': -9999,
    })

    qa_img = ee.Image(landsat_img.select(['QA_PIXEL']))

    def qa_bit_mask(bit):
        """Return an image that is 1 where QA_PIXEL bit `bit` is set, else 0."""
        return qa_img.rightShift(bit).bitwiseAnd(1).neq(0)

    # QA_PIXEL bits used: 1 dilated cloud, 2 cirrus, 3 cloud, 4 cloud shadow, 5 snow, 7 water.
    # Each category after cloud excludes pixels already flagged, so the per-category
    # counts do not overlap. fmask_mask accumulates everything except water
    # (snow is included in the mask for now).
    cloud_mask = qa_bit_mask(3)
    fmask_mask = cloud_mask

    cirrus_mask = qa_bit_mask(2).And(fmask_mask.Not())
    fmask_mask = fmask_mask.Or(cirrus_mask)

    dilate_mask = qa_bit_mask(1).And(fmask_mask.Not())
    fmask_mask = fmask_mask.Or(dilate_mask)

    shadow_mask = qa_bit_mask(4).And(fmask_mask.Not())
    fmask_mask = fmask_mask.Or(shadow_mask)

    snow_mask = qa_bit_mask(5).And(fmask_mask.Not())
    fmask_mask = fmask_mask.Or(snow_mask)

    # Water is counted (excluding already-flagged pixels) but is NOT added to fmask_mask
    water_mask = qa_bit_mask(7).And(fmask_mask.Not())

    # CGM - This isn't working correctly, don't apply
    # # # Apply a small erosion/dilation
    # # fmask_mask = (
    # #     fmask_mask
    # #     .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=1, units='pixels'))
    # #     .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.circle(radius=2, units='pixels'))
    # #     # .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=30, units='meters'))
    # #     # .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.circle(radius=60, units='meters'))
    # #     # .reproject(qa_img.projection())
    # # )

    # Saturation: flag a pixel if any of its red/green/blue bands is saturated.
    #   The RGB bits in QA_RADSAT start at bit 0 for TM/ETM+ and bit 1 for OLI, so the
    #   image is shifted by that amount and the next 3 bits are kept (values 0-7).
    #   Change .gt(0) to .eq(7) to flag only pixels where all three RGB bands are saturated;
    #   remove both the rightShift and bitwiseAnd to flag any set QA_RADSAT bit.
    # Only saturated pixels that are not already in fmask_mask are kept.
    bitshift = ee.Dictionary({'LT04': 0, 'LT05': 0, 'LE07': 0, 'LC08': 1, 'LC09': 1})
    saturated_mask = (
        landsat_img.select('QA_RADSAT')
        .rightShift(ee.Number(bitshift.get(landsat_type))).bitwiseAnd(7)
        .gt(0)
    )
    saturated_mask = saturated_mask.where(fmask_mask, 0)

    # Simple cloud score (ACCA): flag pixels whose TOA cloud score reaches 100,
    # keeping only pixels that are not already in fmask_mask
    acca_mask = ee.Algorithms.Landsat.simpleCloudScore(landsat_toa_img).select(['cloud']).gte(100)
    acca_mask = acca_mask.where(fmask_mask, 0)

    # Flip to set flagged pixels to 0 and clear pixels to 1 (for updateMask)
    fmask_update_mask = fmask_mask.Not()

    # Shared reduceRegion settings: scene footprint, image CRS, 30 m grid with the
    # 15 m offset used by the Landsat grid, and no automatic rescaling
    rr_params = {
        'geometry': qa_img.geometry(),
        'crs': qa_img.projection().crs(),
        'crsTransform': [30, 0, 15, 0, -30, 15],
        'bestEffort': False,
        'maxPixels': 1E12,
    }
    # Raise (e.g. 2 or 4) if reduceRegion runs out of memory
    tile_scale = 1
    if tile_scale != 1:
        rr_params['tileScale'] = tile_scale

    # Collection 2 Level-2 SR scaling (DN * 0.0000275 - 0.2), clamped to [0, 1]
    refl_sr_nomask_bands = (
        landsat_img.select(['SR_RED', 'SR_GREEN', 'SR_BLUE'])
        .multiply([0.0000275]).add([-0.2]).clamp(0, 1)
    )
    refl_sr_masked_bands = (
        landsat_img.select(
            ['SR_RED', 'SR_GREEN', 'SR_BLUE'],
            ['UNMASKED_SR_RED', 'UNMASKED_SR_GREEN', 'UNMASKED_SR_BLUE']
        )
        .multiply([0.0000275]).add([-0.2]).clamp(0, 1)
        .updateMask(fmask_update_mask)
    )
    refl_toa_nomask_bands = (
        landsat_toa_img.select(
            refl_toa_bands_dict.get(landsat_type),
            ['TOA_RED', 'TOA_GREEN', 'TOA_BLUE']
        )
    )
    refl_toa_masked_bands = (
        landsat_toa_img.select(
            refl_toa_bands_dict.get(landsat_type),
            ['UNMASKED_TOA_RED', 'UNMASKED_TOA_GREEN', 'UNMASKED_TOA_BLUE']
        )
        .updateMask(fmask_update_mask)
    )
    refl_mean_stats = (
        refl_sr_nomask_bands
        .addBands(refl_sr_masked_bands)
        .addBands(refl_toa_nomask_bands)
        .addBands(refl_toa_masked_bands)
        .updateMask(land_mask)
        .reduceRegion(reducer=ee.Reducer.mean().unweighted(), **rr_params)
    )

    # Pixel counts: UNMASKED_PIXELS counts clear pixels, TOTAL_PIXELS counts all pixels,
    # and each *_PIXELS flag band is self-masked so only flagged pixels are counted
    count_stats = (
        landsat_img.select(['SR_RED'], ['UNMASKED_PIXELS']).updateMask(fmask_update_mask)
        .addBands([
            landsat_img.select(['SR_RED'], ['TOTAL_PIXELS']),
            cloud_mask.selfMask().rename(['CLOUD_PIXELS']),
            cirrus_mask.selfMask().rename(['CIRRUS_PIXELS']),
            dilate_mask.selfMask().rename(['DILATE_PIXELS']),
            shadow_mask.selfMask().rename(['SHADOW_PIXELS']),
            snow_mask.selfMask().rename(['SNOW_PIXELS']),
            water_mask.selfMask().rename(['WATER_PIXELS']),
            saturated_mask.selfMask().rename(['SATURATED_PIXELS']),
            acca_mask.selfMask().rename(['ACCA_PIXELS']),
        ])
        .updateMask(land_mask)
        .reduceRegion(reducer=ee.Reducer.count().unweighted(), **rr_params)
    )

    # # Compute the SSEBop ETf stats
    # ssebop_stats = (
    #     ee.Image(landsat_img).select('SSEBOP_ETF').divide(10000)
    #     .updateMask(land_mask)
    #     .reduceRegion(
    #         reducer=ee.Reducer.mean().unweighted()
    #             .combine(ee.Reducer.count().unweighted(), '', True)
    #             # .setOutputs(['SSEBOP_ETF_MEAN', 'SSEBOP_ETF_COUNT'])
    #         ,
    #         geometry=qa_img.geometry(),
    #         crs=qa_img.projection().crs(),
    #         crsTransform=[30, 0, 15, 0, -30, 15],
    #         bestEffort=False,
    #         maxPixels=1E10,
    #     )
    # )

    # # Compute the grid stats for the default cloud mask
    # mask_img = fmask_mask.updateMask(land_mask)
    # # image_url = (
    # #     grid_mask_img.visualize(min=0, max=1, palette=['orange', 'blue'])
    # #     .getThumbURL({'region': mask_img.geometry().bounds(1, 'EPSG:4326'), 'dimensions': 1024})
    # # )
    # # ipyplot.plot_images([image_url], img_width=1024)

    # # First build the covering grid
    # # 30, 60, 120, 240, 480, 960, 1920, 3840, 7680, 15360, 30720, 61140, 122880
    # grid_coll = covering_grid(qa_img, 7680)
    # # grid_coll = qa_img.geometry().coveringGrid(qa_img.projection().crs(), 7680)
    # # grid_coll = qa_img.geometry().coveringGrid(qa_img.projection().crs(), 15360)

    # # Compute the scene percent cloud
    # # We may already have this value in this function but recompute for now
    # image_masked_pct = ee.Number(mask_img.rename('masked_pct').reduceRegion(**rr_mean_params).get('masked_pct'))
    # #print(image_cloud_pct.getInfo())

    # # Then compute the stats for each grid cell
    # def grid_stats_func(ftr):
    #     stats = mask_img.addBands(mask_img.mask()).rename(['masked', 'total']).reduceRegion(**{
    #         'reducer': ee.Reducer.mean().unweighted(), 'geometry': ftr.geometry(),
    #         'crs': qa_img.projection().crs(), 'crsTransform': [30, 0, 15, 0, -30, 15],
    #         'bestEffort': False, 'maxPixels': 1E12,
    #     })

    #     stats = stats.set('masked_dec', ee.Number(stats.get('masked')).multiply(10).round().divide(10))

    #     # Start with a dictionary set to 0 to handle the grid cells that don't intersect the image
    #     stats = ee.Dictionary({'masked': 0, 'total': 0}).combine(stats, True)

    #     # Compute the difference in cloudiness from the scene average
    #     #stats = stats.set('diff', ee.Number(stats.get('masked')).subtract(image_masked_pct))

    #     return ee.Feature(ftr.geometry(), stats)

    # # Only keep grid cells that cover the unmasked portion of the scene by some amount
    # min_total = 0.5
    # grid_coll = ee.FeatureCollection(grid_coll.map(grid_stats_func)).filter(ee.Filter.gt('total', min_total))
    # grid_stats = grid_coll.aggregate_stats('diff')
    # grid_histogram
    # grid_stats = ee.Dictionary({
    #     'GRID_COUNT': grid_stats.get('total_count'),
    #     'GRID_MASK_PCT': image_masked_pct,
    #     # 'GRID_COUNT': grid_stats.get('valid_count'),
    #     # 'GRID_DIFF_MEAN': grid_stats.get('mean'),
    #     # 'GRID_DIFF_MIN': grid_stats.get('min'),
    #     # 'GRID_DIFF_MAX': grid_stats.get('max'),
    #     # 'GRID_DIFF_SD': grid_stats.get('total_sd'),
    #     # 'GRID_DIFF_VAR': grid_stats.get('total_var'),
    #     # 'GRID_DIFF_SUMSQ': grid_stats.get('sum_sq'),
    # })

    # Reducer results overwrite the -1 placeholders
    output_stats = (
        default_stats
        .combine(refl_mean_stats, overwrite=True)
        .combine(count_stats, overwrite=True)
    )

    return ee.Feature(None, output_stats)


# image_id = 'LANDSAT/LC08/C02/T1_L2/LC08_019037_20200911'  # random
# # image_id = 'LANDSAT/LC09/C02/T1_L2/LC09_023034_20220703'  # clustered
# # image_id = 'LANDSAT/LC09/C02/T1_L2/LC09_029027_20240328'  # clustered
# test_img = (
#     ee.Image(image_id)
#     .select(refl_sr_bands_dict.get('LE07'), ['SR_RED', 'SR_GREEN', 'SR_BLUE', 'QA_PIXEL', 'QA_RADSAT'])
#     .set('scene_id', image_id.split('/')[-1])
#     .set('landsat_toa_img', ee.Image(image_id.replace('T1_L2', 'T1_TOA')))
# )
# output = image_stats(test_img)
# pprint.pprint(output.getInfo())


In [None]:
# Compute stats and save CSV to bucket
#start_year = 1984
#end_year = 2023
start_year = 2024
end_year = 2025
years = list(range(start_year, end_year+1))
delay = 1  # seconds to pause after starting each export task

overwrite_flag = True
local_files = os.listdir(stats_ws)


def _year_sensors(year):
    """Return the Landsat sensor codes to merge for `year`, in merge order.

    NOTE(review): Landsat 9 imagery exists from late 2021, but 2021-2022 use
    Landsat 8 + 7 here; confirm that is intentional.
    """
    if year < 1993:
        return ['LT05', 'LT04']
    if year < 1999:
        return ['LT05']
    if year < 2013:
        return ['LT05', 'LE07']
    if year < 2023:
        return ['LC08', 'LE07']
    return ['LC09', 'LC08']


def _tile_coll(coll_id, wrs2_tile, start_date, end_date, src_bands=None):
    """Filter one Landsat collection to a WRS2 path/row and date range.

    If `src_bands` is given, those bands are selected and renamed to refl_sr_bands.
    Each image gets a 'scene_id' copied from system:index before any merge,
    since merge() prefixes system:index values ('1_', '2_').
    """
    coll = (
        ee.ImageCollection(coll_id)
        .filterDate(start_date, end_date)
        .filter(ee.Filter.eq('WRS_PATH', int(wrs2_tile[1:4])))
        .filter(ee.Filter.eq('WRS_ROW', int(wrs2_tile[5:8])))
    )
    if src_bands is not None:
        coll = coll.select(src_bands, refl_sr_bands)
    return coll.map(lambda img: img.set({'scene_id': img.get('system:index')}))


def _merged_coll(sensors, coll_suffix, wrs2_tile, start_date, end_date, select_sr_bands=False):
    """Build and merge the per-sensor collections (e.g. suffix 'T1_L2' or 'T1_TOA')."""
    colls = [
        _tile_coll(
            f'LANDSAT/{sensor}/C02/{coll_suffix}', wrs2_tile, start_date, end_date,
            refl_sr_bands_dict.get(sensor) if select_sr_bands else None,
        )
        for sensor in sensors
    ]
    merged = colls[0]
    for coll in colls[1:]:
        merged = merged.merge(coll)
    return merged


print('\nIterating by year and wrs2')
bucket_object = storage_client.get_bucket(BUCKET_NAME)
for year in years:

    print(f'{year} - reading bucket files')
    # Trailing slash so the prefix only matches this year's folder
    bucket_files = {
        os.path.basename(blob.name)
        for blob in bucket_object.list_blobs(prefix=f'{BUCKET_FOLDER}/{year}/')
    }

    start_date = ee.Date.fromYMD(year, 1, 1)
    end_date = start_date.advance(1, 'year')
    sensors = _year_sensors(year)

    for wrs2_tile in wrs2_list:

        # if int(wrs2_tile[1:4]) != 20:
        #     continue
        # if int(wrs2_tile[1:4]) not in range(10, 45):
        #     continue

        csv_name = f'{wrs2_tile}_{year}.csv'
        if csv_name in bucket_files:
            if not overwrite_flag:
                continue
            print('  removing csv from bucket')
            bucket_object.blob(f'{BUCKET_FOLDER}/{year}/{csv_name}').delete()

        print(f'{wrs2_tile}')

        landsat_sr_coll = _merged_coll(
            sensors, 'T1_L2', wrs2_tile, start_date, end_date, select_sr_bands=True
        )
        if landsat_sr_coll.size().getInfo() == 0:
            print('  no landsat sr images in year/tile')
            continue
        landsat_toa_coll = _merged_coll(sensors, 'T1_TOA', wrs2_tile, start_date, end_date)

        # Attach the matching TOA image to each SR image as 'landsat_toa_img'
        # (SR scenes without a TOA match are dropped since outer=False by default)
        landsat_coll = ee.Join.saveFirst(matchKey='landsat_toa_img').apply(
            landsat_sr_coll, landsat_toa_coll,
            ee.Filter.equals(leftField='scene_id', rightField='scene_id'),
        )

        # CGM - Not computing the SSEBop ETf stats for now
        # ssebop_etf_coll_id = 'projects/usgs-gee-nhm-ssebop/assets/ssebop/landsat/c02'
        # ssebop_etf_coll = (
        #     ee.ImageCollection(ssebop_etf_coll_id)
        #     .filterDate(f'{year}-01-01', f'{year+1}-01-01')
        #     .filterMetadata('wrs2_tile', 'equals', wrs2_tile)
        #     .select(['et_fraction'], ['SSEBOP_ETF'])
        # )
        # join_coll = ee.ImageCollection(landsat_coll).linkCollection(
        #     ssebop_etf_coll, linkedBands=['SSEBOP_ETF'], matchPropertyName='scene_id'
        # )
        # output_coll = join_coll.map(image_stats)

        # Compute the image statistics
        output_coll = ee.ImageCollection(landsat_coll).map(image_stats)

        print('  Starting export task')
        task = ee.batch.Export.table.toCloudStorage(
            output_coll,
            f'{wrs2_tile}_{year}_scene_stats',
            bucket=BUCKET_NAME,
            fileNamePrefix=f'{BUCKET_FOLDER}/{year}/{wrs2_tile}_{year}',
        )
        task.start()

        time.sleep(delay)

print('\nDone')


In [None]:
# # Clean up the CSV files to remove unneeded columns
# start_year = 2024
# end_year = 2024

# for year in range(start_year, end_year + 1):
#     print(f'{year}')
#     for wrs2_tile in sorted(wrs2_list):        
#         wrs2_stats_path = os.path.join(stats_ws, f'{year}', f'{wrs2_tile}_{year}.csv')
#         if not os.path.isfile(wrs2_stats_path):
#             # print(f'  {wrs2_tile}_{year} - Missing stats CSV, skipping')
#             continue
#         # print(f'  {wrs2_tile}')
            
#         try:
#             wrs2_stats_df = pd.read_csv(wrs2_stats_path)
#             # wrs2_stats_df.reset_index(drop=True, inplace=True)
#         except Exception as e:
#             print(f'  {wrs2_tile}_{year} - Error reading CSV, skipping')
#             continue
#         if wrs2_stats_df.empty:
#             continue
            
#         for k in ['system:index', '.geo', 'Unnamed: 0', 'Unnamed: 0.1', 'Unnamed: 0.2', 'Unnamed: 0.3', 'Unnamed: 0.4', 'DATE']:
#             try:
#                 wrs2_stats_df.drop(columns=[k], inplace=True)
#             except:
#                 pass
#         if any(['unnamed' in c.lower() for c in wrs2_stats_df.columns]):
#             print(f'{wrs2_tile}_{year}.csv')
#             print(wrs2_stats_df.columns)
#             input('ENTER')

#         # if 'DATE' not in wrs2_stats_df.columns:
#         #     wrs2_stats_df['DATE'] = wrs2_stats_df['SCENE_ID'].str.slice(12, 20)
#         if 'WRS2' not in wrs2_stats_df.columns:
#             wrs2_stats_df['WRS2'] = 'p' + wrs2_stats_df['SCENE_ID'].str.slice(5, 8) + 'r' + wrs2_stats_df['SCENE_ID'].str.slice(8, 11)

#         # Force the Moran columns to a float type
#         for k in ['MORAN_1K', 'MORAN_2K', 'MORAN_4K', 'MORAN_8K']:
#             if wrs2_stats_df[k].dtype == 'int64':
#                 wrs2_stats_df[k] = wrs2_stats_df[k].astype(float)
            
#         # print(f'  {wrs2_stats_path.split("/")[-1]}')
#         wrs2_stats_df.to_csv(wrs2_stats_path, index=False)


In [None]:
# # 
# def cloud_mask_func(landsat_img):
#     scene_id = ee.String(ee.Image(landsat_img).get('scene_id'))
#     print(scene_id.getInfo())
#     landsat_type = scene_id.slice(0, 4)
    
#     default_stats = ee.Dictionary({
#         'MASKED_COUNT': -1, 'NOMASK_COUNT': -1,
#         'SR_NOMASK_RED': -1, 'SR_NOMASK_GREEN': -1, 'SR_NOMASK_BLUE': -1,
#         'SR_MASKED_RED': -1, 'SR_MASKED_GREEN': -1, 'SR_MASKED_BLUE': -1,
#         'CLOUD_COVER_LAND': landsat_img.get('CLOUD_COVER_LAND'),  
#         'SCENE_ID': scene_id,
#     })

#     # Get the cloud mask (including the snow mask for now)
#     qa_img = ee.Image(landsat_img.select(['QA_PIXEL']))
#     cloud_mask = qa_img.rightShift(3).bitwiseAnd(1).neq(0)
    
#     # if cirrus_flag:
#     cirrus_mask = qa_img.rightShift(2).bitwiseAnd(1).neq(0)
#     cloud_mask = cloud_mask.Or(cirrus_mask)
    
#     # if dilate_flag:
#     dilate_mask = qa_img.rightShift(1).bitwiseAnd(1).neq(0)
#     cloud_mask = cloud_mask.Or(dilate_mask)

#     # if shadow_flag:
#     shadow_mask = qa_img.rightShift(4).bitwiseAnd(1).neq(0)
#     cloud_mask = cloud_mask.Or(shadow_mask)
    
#     # if snow_flag:
#     snow_mask = qa_img.rightShift(5).bitwiseAnd(1).neq(0)
#     cloud_mask = cloud_mask.Or(snow_mask)

#     # if water_flag:
#     # water_mask = qa_img.rightShift(7).bitwiseAnd(1).neq(0)
#     # cloud_mask = cloud_mask.Or(water_mask)
    
#     # if saturated_flag:
#     #     # Masking if saturated in any band
#     #     sat_mask = input_img.select(['QA_RADSAT']).gt(0)
#     #     cloud_mask = cloud_mask.Or(sat_mask)

#     # # Apply a small erosion/dilation
#     # cloud_mask = (
#     #     cloud_mask
#     #     .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=30, units='meters'))
#     #     .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.circle(radius=120, units='meters'))
#     #     # .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=1, units='pixels'))
#     #     # .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.circle(radius=2, units='pixels'))
#     #     # .reproject(qa_img.projection())
#     # )

#     # Flip to set cloudy pixels to 0 and clear to 1
#     cloud_mask = cloud_mask.Not()       

#     return cloud_mask


# refl_sr_bands_dict = ee.Dictionary({
#     'LT05': ['SR_B3', 'SR_B2', 'SR_B1', 'QA_PIXEL'],
#     'LE07': ['SR_B3', 'SR_B2', 'SR_B1', 'QA_PIXEL'],
#     'LC08': ['SR_B4', 'SR_B3', 'SR_B2', 'QA_PIXEL'],
#     'LC09': ['SR_B4', 'SR_B3', 'SR_B2', 'QA_PIXEL'],
# })


# wrs2_tile = 'p048r027'
# year = 2024

# # l8_sr_coll = (
# # landsat_coll = (
# #     ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
# #     .filterDate(f'{year}-02-01', f'{year+1}-01-01')
# #     .filterMetadata('WRS_PATH', 'equals', int(wrs2_tile[1:4]))
# #     .filterMetadata('WRS_ROW', 'equals', int(wrs2_tile[5:8]))
# #     .select(refl_sr_bands_dict.get('LC08'), ['SR_RED', 'SR_GREEN', 'SR_BLUE', 'QA_PIXEL'])
# #     .map(lambda img: img.set('scene_id', img.get('system:index')))
# # )
# # l9_sr_coll = (
# landsat_coll = (
#     ee.ImageCollection('LANDSAT/LC09/C02/T1_L2')
#     .filterDate(f'{year}-02-01', f'{year+1}-01-01')
#     .filterMetadata('WRS_PATH', 'equals', int(wrs2_tile[1:4]))
#     .filterMetadata('WRS_ROW', 'equals', int(wrs2_tile[5:8]))
#     .select(refl_sr_bands_dict.get('LC09'), ['SR_RED', 'SR_GREEN', 'SR_BLUE', 'QA_PIXEL'])
#     .map(lambda img: img.set('scene_id', img.get('system:index')))
# )
# # landsat_coll = l9_sr_coll.merge(l8_sr_coll)

# test_img = ee.Image(landsat_coll.first())
# output_img = cloud_mask_func(test_img)

# landsat_region = test_img.select([0]).geometry().bounds(1, 'EPSG:4326')
# image_size = 768

# image_url = (
#     test_img.select(['SR_RED', 'SR_GREEN', 'SR_BLUE']).multiply([0.0000275, 0.0000275, 0.0000275]).add([-0.2, -0.2, -0.2])
#     .where(land_mask.unmask().eq(0), 0.25)
#     .getThumbURL({'min': 0.0, 'max': 0.30, 'region': landsat_region, 'dimensions': image_size})
# )
# output_url = (
#     ee.Image(output_img)
#     .where(land_mask.unmask().eq(0), 0.80)
#     .getThumbURL({'min': 0, 'max': 1, 'region': landsat_region, 'dimensions': image_size})
# )

# ipyplot.plot_images([image_url, output_url], img_width=image_size)
