# In [None]:
# ee_ndvi_seasonal_year_per_state.py
# Minimal EE pipeline:
# - mean NDVI (Landsat 8/9 SR)
# - seasonal-year aggregation (DJF uses Dec of previous year)
# - exports ONE STATE PER FILE: (state × season × buffer)

import ee

# ---------------------------
# CONFIG
# ---------------------------
CONFIG = dict(
    # Seasonal year: DJF spans Dec of the previous year through Feb of this one.
    year=2022,
    # Buffer distances (metres) applied to each state polygon, e.g. [0, 90, 180].
    buffers_m=[0],
    # Reduction scale in metres: ~250 m for speed; later set 30.
    scale_m=250,
    # reduceRegions tileScale; raise it if the job hits memory errors.
    tile_scale=4,
    drive_folder="EE_NDVI_Exports",
    filename_prefix="ndvi_mean_state",
    # USPS codes to drop. None keeps every state & territory in the source
    # collection; the default list removes the non-CONUS entries.
    exclude_states=['AK', 'HI', 'PR', 'VI', 'GU', 'AS', 'MP', 'UM'],
)

# ---------------------------
# EE init
# ---------------------------
def init_ee(project=None):
    """Initialise the Earth Engine client, authenticating interactively if needed.

    Args:
        project: Optional Google Cloud project ID forwarded to
            ``ee.Initialize(project=...)``. ``None`` (the default) calls
            ``ee.Initialize()`` with no arguments, exactly as before.

    The first ``ee.Initialize`` attempt is best-effort. Any exception triggers
    ``ee.Authenticate()`` followed by a second ``ee.Initialize``. An error from
    that second attempt propagates to the caller.
    """
    # Only pass `project` when given, so the default path is unchanged.
    init_kwargs = {} if project is None else {"project": project}
    try:
        ee.Initialize(**init_kwargs)
    except Exception:
        # Deliberately broad: any failure here is treated as "needs auth".
        ee.Authenticate()
        ee.Initialize(**init_kwargs)

# ---------------------------
# Landsat SR → (red, nir) → NDVI
# ---------------------------
def ls_mask_and_scale(img, mask_bits=(3, 4, 5)):
    """Mask flagged pixels and convert SR_B4/SR_B5 to surface reflectance.

    Args:
        img: A Landsat 8/9 Collection 2 Level-2 ``ee.Image``. It must carry
            the ``QA_PIXEL``, ``SR_B4`` (red) and ``SR_B5`` (NIR) bands.
        mask_bits: ``QA_PIXEL`` bit positions that must all be 0 for a pixel
            to be kept. The default (3, 4, 5) masks cloud (bit 3), cloud
            shadow (bit 4) and snow (bit 5). Pass e.g. (1, 2, 3, 4, 5) to
            also drop dilated cloud (bit 1) and cirrus (bit 2).

    Returns:
        A two-band image named ``['red', 'nir']`` holding scaled reflectance,
        with the QA mask applied.
    """
    qa = img.select('QA_PIXEL')
    # Fold the requested bits into one bitmask. A pixel survives only when
    # none of those bits is set, which is equivalent to AND-ing a separate
    # `bitwiseAnd(1 << b).eq(0)` test for every bit.
    flag_mask = sum(1 << b for b in set(mask_bits))
    mask = qa.bitwiseAnd(flag_mask).eq(0)

    def scale(b):  # USGS C2 L2 reflectance: ρ = DN*0.0000275 - 0.2
        return img.select(b).multiply(0.0000275).add(-0.2)

    red = scale('SR_B4')
    nir = scale('SR_B5')
    return ee.Image.cat([red, nir]).rename(['red', 'nir']).updateMask(mask)

def landsat_sr_collection(start, end):
    """Return merged Landsat 8 + 9 C2 L2 scenes in [start, end), masked and scaled.

    Scenes with scene-level CLOUD_COVER of 80 or more are dropped. Each
    remaining scene is passed through ``ls_mask_and_scale``.
    """
    lc08 = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
    lc09 = ee.ImageCollection("LANDSAT/LC09/C02/T1_L2")
    both = lc08.merge(lc09)
    in_window = both.filterDate(start, end)
    usable = in_window.filter(ee.Filter.lt('CLOUD_COVER', 80))
    return usable.map(ls_mask_and_scale)

def mean_ndvi_image(start, end):
    """Return the per-pixel mean NDVI of Landsat 8/9 SR scenes in [start, end).

    NDVI is computed per scene as ``(nir - red) / (nir + red)`` on the
    reflectance bands produced by ``landsat_sr_collection``. The scenes are
    then averaged into a single-band image named ``'ndvi'``.

    Pixels where either scaled reflectance is <= 0 are masked before
    averaging. The C2 L2 offset (-0.2) can yield non-positive reflectance.
    Once one band is negative, the ratio is no longer bounded by [-1, 1], and
    a near-zero denominator produces arbitrarily large values that would skew
    the mean. With both bands > 0, NDVI is guaranteed to lie in (-1, 1).
    """
    def add_ndvi(im):
        red = im.select('red')
        nir = im.select('nir')
        valid = red.gt(0).And(nir.gt(0))
        ndvi = nir.subtract(red).divide(nir.add(red)).rename('ndvi')
        return im.addBands(ndvi.updateMask(valid))

    col = landsat_sr_collection(start, end).map(add_ndvi).select('ndvi')
    return col.mean().rename('ndvi')

# ---------------------------
# Seasonal-year windows
# ---------------------------
def seasonal_year_ranges(year):
    """Return the four season windows of a seasonal year as (tag, start, end).

    DJF starts on Dec 1 of ``year - 1``. Each window ends on the first day of
    the month where the next window begins; SON ends on Dec 1 of ``year``.
    Tags look like ``"2022_DJF"``.
    """
    # (label, (start year, start month), (end year, end month)); all on day 1.
    windows = (
        ("DJF", (year - 1, 12), (year, 3)),
        ("MAM", (year, 3), (year, 6)),
        ("JJA", (year, 6), (year, 9)),
        ("SON", (year, 9), (year, 12)),
    )
    return [
        (f"{year}_{label}", ee.Date.fromYMD(sy, sm, 1), ee.Date.fromYMD(ey, em, 1))
        for label, (sy, sm), (ey, em) in windows
    ]

# ---------------------------
# States + helpers
# ---------------------------
def get_states_fc(exclude_states=None):
    """Return TIGER/2018/States, minus any features whose STUSPS is excluded.

    ``exclude_states=None`` returns the collection unfiltered.
    """
    states = ee.FeatureCollection("TIGER/2018/States")
    if exclude_states is not None:
        dropped = ee.Filter.inList('STUSPS', exclude_states)
        states = states.filter(dropped.Not())
    return states

def list_state_codes(base_fc):
    """Fetch the STUSPS codes of *base_fc* to the client, sorted ascending."""
    # getInfo() is a blocking server round-trip; the result is a plain list.
    codes = base_fc.aggregate_array('STUSPS').getInfo()
    return sorted(codes)

def state_feature(base_fc, stusps):
    """Return the first feature in *base_fc* whose STUSPS equals *stusps*."""
    matches = base_fc.filter(ee.Filter.eq('STUSPS', stusps))
    return matches.first()

def buffered_fc_of_one_state(base_fc, stusps, buffer_m):
    """Wrap one state in a FeatureCollection, buffered and tagged with buffer_m.

    A ``buffer_m`` of 0 skips the buffer operation and tags the feature with 0.
    """
    single = ee.FeatureCollection([state_feature(base_fc, stusps)])
    if buffer_m != 0:
        return single.map(
            lambda feat: feat.buffer(buffer_m).set({'buffer_m': buffer_m}))
    return single.map(lambda feat: feat.set({'buffer_m': 0}))

# ---------------------------
# Reduce & export (per state)
# ---------------------------
def reduce_state(img, regions_fc, period_tag, buffer_m, scale, tile_scale):
    """Mean-reduce *img* over each region and tag every row with run metadata.

    Returns the ``reduceRegions`` output with period, stat, index, buffer_m and
    scale_m properties added to each feature.
    """
    metadata = {
        "period": period_tag,
        "stat": "mean",
        "index": "ndvi",
        "buffer_m": buffer_m,
        "scale_m": scale,
    }
    reduced = img.reduceRegions(
        collection=regions_fc,
        reducer=ee.Reducer.mean(),
        scale=scale,
        tileScale=tile_scale,
    )
    return reduced.map(lambda feat: feat.set(metadata))

def export_table(fc, description, filename_prefix, folder, selectors=None):
    """Start a CSV export of *fc* to Google Drive and return the started task.

    Args:
        fc: FeatureCollection to export.
        description: Task name shown in the Earth Engine task list.
        filename_prefix: Drive file name, without extension.
        folder: Drive folder to write into.
        selectors: Optional list of property names to keep as CSV columns.
            ``None`` (the default) exports every column, as before. Passing
            an explicit list lets callers leave out the per-feature geometry
            column, which otherwise carries each state polygon.

    Returns:
        The ``ee.batch.Task``, already started.
    """
    task = ee.batch.Export.table.toDrive(
        collection=fc,
        description=description,
        folder=folder,
        fileNamePrefix=filename_prefix,
        fileFormat='CSV',
        selectors=selectors,
    )
    task.start()
    return task

# ---------------------------
# Main (per-state exports)
# ---------------------------
def main():
    """Start one Drive CSV export per (season, state, buffer) combination.

    Reads every setting from CONFIG. The season loop is outermost, so the mean
    NDVI image is built once per season and reused for all states and buffers.
    """
    init_ee()

    cfg = CONFIG
    states = get_states_fc(cfg.get("exclude_states")).select(["GEOID", "NAME", "STUSPS"])
    codes = list_state_codes(states)  # e.g., ['AL','AR',...,'WY']
    # Loop-invariant; shared by every export's description and file name.
    scale_tag = f"s{int(cfg['scale_m'])}m"

    for period_tag, start, end in seasonal_year_ranges(cfg["year"]):
        ndvi_img = mean_ndvi_image(start, end)

        for st in codes:
            for buffer_m in cfg["buffers_m"]:
                table_fc = reduce_state(
                    img=ndvi_img,
                    regions_fc=buffered_fc_of_one_state(states, st, buffer_m),
                    period_tag=period_tag,
                    buffer_m=buffer_m,
                    scale=cfg["scale_m"],
                    tile_scale=cfg["tile_scale"],
                )

                suffix = f"{period_tag}_{st}_b{buffer_m}m_{scale_tag}"
                desc = f"ndvi_mean_{suffix}"
                fname = f"{cfg['filename_prefix']}_{suffix}"

                export_table(table_fc, desc, fname, cfg["drive_folder"])
                print("Started export:", desc)


if __name__ == "__main__":
    main()
