In [None]:
import collections
import os
import pprint

import ee
import numpy as np
import pandas as pd

from IPython.display import Image, display
import ipyplot


In [None]:
# Initialize Earth Engine against the high-volume endpoint; this notebook
# issues one thumbnail request per image it displays.
# Set the EE_PROJECT environment variable to run under a different Cloud
# project than the original author's default.
ee.Initialize(
    project=os.environ.get('EE_PROJECT', 'ee-cmorton'),
    opt_url='https://earthengine-highvolume.googleapis.com',
)

In [None]:
# Analysis period (both years inclusive)
start_year, end_year = 2015, 2024
years = [*range(start_year, end_year + 1)]

# WRS2 tiles to leave out of the review
wrs2_skip_list = [
    'p010r027',
    # Canada scenes (not sure why they are in WRS2 list)
    'p035r025',
    'p034r025',
    'p033r025',
    'p032r025',
    'p031r025',
]

# Land mask (nonzero = land, 0 = water/ocean).
# Start from the inverted OpenET water mask (assumes that asset is nonzero
# over water - TODO confirm), then also set to 0 every pixel NALCMS labels as
# class 18 (water). Masked NALCMS pixels are unmasked to 18 first, so areas
# outside the NALCMS footprint are treated as water as well.
land_mask = (
    ee.Image('projects/openet/assets/features/water_mask')
    .Not()
    .where(ee.Image("USGS/NLCD_RELEASES/2020_REL/NALCMS").unmask(18).eq(18), 0)
)
# land_mask = land_mask.And(ee.Image("USGS/NLCD_RELEASES/2020_REL/NALCMS").unmask(18).neq(18))
# # land_mask = ee.Image('projects/openet/assets/meteorology/conus404/ancillary/land_mask')

# Thumbnail size in pixels (other sizes tried: 1024, 700)
image_size = 900

# True color (red, green, blue) surface reflectance band names per sensor
rgb_bands = {
    **{sensor: ['SR_B3', 'SR_B2', 'SR_B1'] for sensor in ('LT04', 'LT05', 'LE07')},
    **{sensor: ['SR_B4', 'SR_B3', 'SR_B2'] for sensor in ('LC08', 'LC09')},
}

# Fmask overlay colors, indexed by code:
#   0 - unflagged (white), 1 - no data / fill (green), 2 - shadow (dark blue),
#   3 - snow (light blue), 4 - cloud (light gray), 5 - water (purple),
#   6 - ocean / non-land (gray)
fmask_palette = "ffffff, 9effa1, blue, 00aff2, dddddd, purple, bfbfbf"
# Highest code = number of palette colors minus one (6)
fmask_max = len(fmask_palette.split(',')) - 1


In [None]:
def fmask(landsat_img, saturated_value=None):
    """Build an old "Fmask"-style classification image from a Landsat C02 L2 image.

    Codes are assigned from the QA_PIXEL bit flags in this order, so later
    rules override earlier ones:
        1 - no data (SR_B4 is masked)
        [saturated_value - any RGB band saturated, only if requested]
        5 - water (bit 7)
        2 - cloud shadow (bit 4)
        3 - snow (bit 5)
        4 - cloud, dilated cloud or cirrus (bits 3, 1, 2)
    Pixels that match none of the rules (code 0) are masked out.

    Args:
        landsat_img: Landsat Collection 2 Level-2 ee.Image with QA_PIXEL and
            SR_B4 bands. When saturated_value is set it also needs a
            QA_RADSAT band and a SPACECRAFT_ID property.
        saturated_value: Optional code to assign to pixels where any of the
            RGB bands is radiometrically saturated. It is applied right after
            the no-data rule, so water/shadow/snow/cloud still take
            precedence. None (the default) disables the saturation flag.

    Returns:
        A single band ee.Image named 'fmask'.
    """
    qa_img = landsat_img.select('QA_PIXEL')

    def qa_bit(bit):
        # Boolean image: True where the given QA_PIXEL bit is set
        return qa_img.rightShift(bit).bitwiseAnd(1).neq(0)

    dilate_mask = qa_bit(1)
    cirrus_mask = qa_bit(2)
    cloud_mask = qa_bit(3)
    shadow_mask = qa_bit(4)
    snow_mask = qa_bit(5)
    water_mask = qa_bit(7)
    # Confidence bit pairs, currently unused:
    #   cloud 8-9, shadow 10-11, snow 12-13, cirrus 14-15

    fmask_img = qa_img.multiply(0).where(landsat_img.select(['SR_B4']).mask().eq(0), 1)

    if saturated_value is not None:
        # The RGB bands are B3/B2/B1 (QA_RADSAT bits 0-2) on Landsat 4/5/7 and
        # B4/B3/B2 (bits 1-3) on Landsat 8/9 (see rgb_bands), so OLI needs a
        # one-bit shift before keeping the low three bits.
        rgb_bitshift = ee.Dictionary({
            'LANDSAT_4': 0, 'LANDSAT_5': 0, 'LANDSAT_7': 0,
            'LANDSAT_8': 1, 'LANDSAT_9': 1,
        })
        saturated_mask = (
            landsat_img.select('QA_RADSAT')
            .rightShift(ee.Number(rgb_bitshift.get(ee.String(landsat_img.get('SPACECRAFT_ID')))))
            .bitwiseAnd(7)
            # Any RGB band saturated; use .eq(7) to require all three
            .gt(0)
        )
        fmask_img = fmask_img.where(saturated_mask, saturated_value)

    fmask_img = (
        fmask_img
        .where(water_mask, 5)
        .where(shadow_mask, 2)
        .where(snow_mask, 3)
        .where(cloud_mask.Or(dilate_mask).Or(cirrus_mask), 4)
    )

    return fmask_img.updateMask(fmask_img.neq(0)).rename(['fmask'])


In [None]:
# WRS2 tiles that have at least one SSEBop image in the 2020-01-01 /
# 2024-01-01 filterDate window, sorted in descending tile order
# etf_coll_id = 'projects/openet/assets/ssebop/conus/gridmet/landsat/c02'
etf_coll_id = 'projects/usgs-gee-nhm-ssebop/assets/ssebop/landsat/c02'
wrs2_list = (
    ee.ImageCollection(etf_coll_id)
    .filterDate('2020-01-01', '2024-01-01')
    .aggregate_histogram('wrs2_tile')
    .keys()
    .getInfo()
)
wrs2_list.sort(reverse=True)

In [None]:
# Review scenes whose CLOUD_COUNT_RATIO falls in a narrow band. Show each
# scene's true color and Fmask overlay thumbnails next to the same-date
# scenes one WRS2 row above and below.
print_count = 100

count_threshold_pct_min = 67
count_threshold_pct_max = 70

# Non-land pixels are painted over in every thumbnail (hoisted out of the loop)
ocean_mask = land_mask.unmask().eq(0)


def _neighbor_scene_id(scene_id, row_offset):
    """Return scene_id with the WRS2 row of its path/row field shifted by row_offset.

    Only the second-to-last '_' field (PPPRRR, e.g. '020031') is rewritten.
    A str.replace() over the whole ID would also rewrite a matching digit
    run inside the date: 'LC08_020031_20200315' -> 'LC08_020030_20200305'.
    """
    parts = scene_id.split('_')
    wrs2_path = int(parts[-2][0:3])
    wrs2_row = int(parts[-2][3:6])
    parts[-2] = f'{wrs2_path:03d}{wrs2_row + row_offset:03d}'
    return '_'.join(parts)


def _cloud_cover_land(scene_id):
    """Return CLOUD_COVER_LAND for scene_id from stats_df, or None if the scene is not listed."""
    scene_stats_df = stats_df.loc[stats_df['SCENE_ID'] == scene_id]
    if scene_stats_df.empty:
        return None
    return float(scene_stats_df.iloc[0]['CLOUD_COVER_LAND'])


def _scaled_rgb(sr_img, landsat_type):
    """Select the sensor's RGB SR bands and apply the scale (0.0000275) and offset (-0.2)."""
    return sr_img.select(rgb_bands[landsat_type]).multiply([0.0000275]).add([-0.2])


def _fmask_overlay_url(sr_img, landsat_type, region, fmask_img=None):
    """Return a thumbnail URL of the true color image with Fmask codes blended on top.

    Non-land pixels get a uniform 0.25 in the true color layer and code
    fmask_max in the overlay. fmask_img defaults to fmask(sr_img).
    """
    if fmask_img is None:
        fmask_img = fmask(sr_img)
    return (
        _scaled_rgb(sr_img, landsat_type).where(ocean_mask, 0.25)
        .visualize(min=0, max=0.3, gamma=1.25)
        .blend(
            fmask_img.selfMask().where(ocean_mask, fmask_max)
            .visualize(bands='fmask', min=0, max=fmask_max, palette=fmask_palette)
        )
        .getThumbURL({'region': region, 'dimensions': image_size})
    )


def _neighbor_url(scene_id, landsat_type):
    """Return the Fmask overlay thumbnail URL for scene_id, or None if Earth Engine rejects it.

    A neighboring scene may not exist for that row/date. In that case the
    thumbnail request raises ee.EEException and the neighbor is skipped.
    """
    neighbor_img = ee.Image(f'LANDSAT/{landsat_type}/C02/T1_L2/{scene_id}')
    neighbor_region = neighbor_img.geometry().bounds(1, 'EPSG:4326')
    try:
        return _fmask_overlay_url(neighbor_img, landsat_type, neighbor_region)
    except ee.EEException:
        return None


subset_df = stats_df[~stats_df['SCENE_ID'].isin(scene_skip_list)]
subset_df = subset_df[~subset_df['SCENE_ID'].isin(scene_cloudscore_list)].copy()
# subset_df = stats_df[stats_df['SCENE_ID'].isin(scene_eemetric_missing_list)].copy()
# subset_df = subset_df[subset_df['WRS2'] == 'p042r035']

# WRS2 strings look like 'p042r035': chars 1-3 are the path, 5-7 the row
subset_df = subset_df[subset_df['WRS2'].str.slice(1,4).astype(int).isin(range(10, 25))].copy()
subset_df = subset_df[subset_df['WRS2'].str.slice(5,8).astype(int).isin(range(30, 40))].copy()

subset_df = subset_df[subset_df['CLOUD_COUNT_RATIO'] < (count_threshold_pct_max / 100)].copy()
subset_df = subset_df[subset_df['CLOUD_COUNT_RATIO'] >= (count_threshold_pct_min / 100)].copy()
subset_df = subset_df.sort_values('CLOUD_COUNT_RATIO', ascending=True)

# new_skip_scenes = []
new_skip_count = 0
for _, row in subset_df.iterrows():
    scene_id = row["SCENE_ID"].upper()
    landsat_type = scene_id.split('_')[0]

    above_scene_id = _neighbor_scene_id(scene_id, -1)
    below_scene_id = _neighbor_scene_id(scene_id, +1)
    above_cloud_pct = _cloud_cover_land(above_scene_id)
    below_cloud_pct = _cloud_cover_land(below_scene_id)

    # if 'REASON' in scene_skip_df.columns:
    #     reason = str(scene_skip_df.loc[scene_skip_df.index == scene_id].iloc[0]['REASON'])
    #     if 'missing' not in reason.lower():
    #         continue
    #     print('#'*80)
    #     print(f'{row["SCENE_ID"]}  {row["CLOUD_COUNT_RATIO"]}  {reason}')
    # else:
    #     print('#'*80)
    #     print(f'{row["SCENE_ID"]}  {row["CLOUD_COUNT_RATIO"]}')
    # # elif:

    landsat_sr_img = ee.Image(f'LANDSAT/{landsat_type}/C02/T1_L2/{scene_id}')
    landsat_region = landsat_sr_img.geometry().bounds(1, 'EPSG:4326')
    landsat_rgb_img = _scaled_rgb(landsat_sr_img, landsat_type)

    fmask_img = fmask(landsat_sr_img)

    # if 'CLOUDSCORE' in scene_skip_df.columns:
    #     landsat_toa_img = ee.Image(f'LANDSAT/{landsat_type}/C02/T1_TOA/{scene_id}')
    #     scene_skip_row = scene_skip_df.loc[(scene_skip_df.index == scene_id)]
    #     cloudscore = float(scene_skip_row.iloc[0]['CLOUDSCORE'])
    #     acca_mask = (
    #         ee.Algorithms.Landsat.simpleCloudScore(landsat_toa_img).select(['cloud']).gte(cloudscore)
    #         # .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=2, units='pixels'))
    #         # .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.octagon(radius=10, units='pixels'))
    #         # .reduceNeighborhood(ee.Reducer.min(), ee.Kernel.circle(radius=150, units='meters'))
    #         # .reduceNeighborhood(ee.Reducer.max(), ee.Kernel.octagon(radius=30, units='meters'))
    #         # .setDefaultProjection(landsat_sr_img.select(['QA_PIXEL']).projection())
    #         # .reproject(landsat_sr_img.select(['QA_PIXEL']).projection())
    #     )
    #     cv_mask = ee.Image(str(scene_skip_row.iloc[0]['MASK']))
    #     fmask_img = fmask_img.where(acca_mask.And(cv_mask), 6)

    # # Old approach for show unmasked FMask pixels that were flagged by ACCA
    # fmask_img = fmask_img.where(acca_mask.And(fmask_img.lt(2)), 6)

    # Landsat true color image
    image_url = (
        landsat_rgb_img.where(ocean_mask, 0.25)
        .getThumbURL({'min': 0.0, 'max': 0.30, 'gamma': 1.25, 'region': landsat_region, 'dimensions': image_size})
    )

    # Landsat true color image with Fmask
    fmask_url = _fmask_overlay_url(landsat_sr_img, landsat_type, landsat_region, fmask_img)

    print('#'*80)
    print(
        f'  {scene_id}  {row["TOTAL_PIXELS"]:>10d}  {row["UNMASKED_PIXELS"]:>10d}'
        f'  ({row["CLOUD_COUNT_RATIO"]:>0.2f}) ({row["SNOW_COUNT_RATIO"]:>0.2f}) {row["CLOUD_COVER_LAND"]}'
    )
    ipyplot.plot_images([image_url, fmask_url], img_width=image_size)

    # Show the images above and below the target wrs2
    above_url = _neighbor_url(above_scene_id, landsat_type)
    below_url = _neighbor_url(below_scene_id, landsat_type)

    above_skipped = ' (skipped)' if above_scene_id in scene_skip_list else ''
    below_skipped = ' (skipped)' if below_scene_id in scene_skip_list else ''

    if above_url and below_url:
        print(f'{below_scene_id} ({below_cloud_pct}){below_skipped}  {above_scene_id} ({above_cloud_pct}){above_skipped}')
        ipyplot.plot_images([below_url, above_url], img_width=image_size)
    elif above_url:
        print(f'{above_scene_id} ({above_cloud_pct}){above_skipped}')
        ipyplot.plot_images([above_url], img_width=image_size)
    elif below_url:
        print(f'{below_scene_id} ({below_cloud_pct}){below_skipped}')
        ipyplot.plot_images([below_url], img_width=image_size)

    # new_skip_scenes.append(row["SCENE_ID"])
    new_skip_count += 1
    if new_skip_count >= print_count:
        break
        # for scene_id in new_skip_scenes:
        #     print(scene_id)
        # new_skip_scenes = []
        # new_skip_count = 0

    # break

print('\nDone')

# if new_skip_scenes:
#     for scene_id in new_skip_scenes:
#         print(scene_id)
