In [None]:
import xarray as xr
import pandas as pd
import os
import glob
from typing import List

# ==========================================
# Configuration
# ==========================================
# ERA5 surface (single-level) variables pulled for every region; the list
# order is the column order of the resulting DataFrames.
VARS_2D: List[str] = [
    "total_precipitation",
    "volumetric_soil_water_layer_1",
    "volumetric_soil_water_layer_2",
    "2m_temperature",
    "surface_solar_radiation_downwards",
    "evaporation",
]

# California study box. Latitude is written north -> south; label slices must
# follow the coordinate's storage order, so this presumably matches a
# descending ERA5 latitude axis.
LAT_CALI = slice(40, 36)
# One longitude band (121W..119W) expressed in both datasets' conventions:
# the MODIS LAI store uses signed degrees, ERA5 uses degrees east offset by 360.
LON_CALI_LAI = slice(-121, -119)
LON_CALI_ERA5 = slice(LON_CALI_LAI.start + 360, LON_CALI_LAI.stop + 360)

# Inclusive range of years to download.
START_YEAR, END_YEAR = 2011, 2020
# Keep every SPATIAL_STEP-th ERA5 grid point per horizontal axis before the
# spatial average (reduces the amount of data fetched).
SPATIAL_STEP = 5

# ==========================================
# ERA5 Weekly Download
# ==========================================
def _era5_hourly_region_mean(ds, lat_slice, lon_slice, time_sel):
    """Return hourly ERA5 values averaged over a thinned lat/lon box.

    Selects ``VARS_2D`` plus ``specific_humidity`` at the last index of the
    ``level`` dimension for ``time_sel``. Keeps every ``SPATIAL_STEP``-th grid
    point along latitude and longitude, averages over both axes, and loads
    the resulting time series into a DataFrame indexed by ``time``.
    ``isel(level=-1)`` leaves a scalar ``level`` coordinate, which
    ``to_dataframe()`` emits as a constant column.
    """
    box = {"latitude": lat_slice, "longitude": lon_slice, "time": time_sel}
    thin = {
        "latitude": slice(None, None, SPATIAL_STEP),
        "longitude": slice(None, None, SPATIAL_STEP),
    }
    surface = ds[VARS_2D].sel(box).isel(thin)
    # NOTE(review): level=-1 takes whichever pressure level is stored last;
    # confirm that is the intended (near-surface) level for this store.
    humidity = ds["specific_humidity"].sel(box).isel(level=-1).isel(thin)
    return (
        xr.merge([surface, humidity])
        .mean(dim=["latitude", "longitude"])
        .load()
        .to_dataframe()
    )


def download_era5_weekly(region_name, lat_slice, lon_slice, start_year, end_year):
    """Download regional ERA5 data and aggregate it to weekly means.

    Every year in ``[start_year, end_year]`` is fetched one month at a time,
    spatially averaged (see ``_era5_hourly_region_mean``), and cached as an
    hourly CSV ``{region_name}_era5_hourly_{year}.csv``. Years that already
    have a cache file are not fetched again, and the remote store is opened
    only when at least one year is missing.

    The hourly records of the requested years are then concatenated and
    resampled to weekly means in a single pass. A week that straddles a
    month or year boundary therefore becomes ONE row averaged over all of
    its hours. Resampling each month separately would emit two partial rows
    carrying the same weekly label.

    Args:
        region_name: Prefix of every CSV file read or written.
        lat_slice: Latitude label slice, in the store's coordinate order.
        lon_slice: Longitude label slice, in the store's convention
            (e.g. ``LON_CALI_ERA5``).
        start_year: First year to include.
        end_year: Last year to include (inclusive).

    Returns:
        Weekly DataFrame indexed by ``time`` with rows containing any NaN
        dropped. It is also written to
        ``{region_name}_era5_full_master_weekly.csv``. Returns None when no
        cached data exists for the requested years.
    """
    years = range(start_year, end_year + 1)
    cache_files = {year: f"{region_name}_era5_hourly_{year}.csv" for year in years}
    missing = [year for year, path in cache_files.items() if not os.path.exists(path)]

    if missing:
        store = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
        ds = xr.open_dataset(store, engine="zarr", chunks={})
        try:
            for year in missing:
                # One month per request keeps every .load() small.
                hourly = pd.concat(
                    [
                        _era5_hourly_region_mean(
                            ds, lat_slice, lon_slice, f"{year}-{month:02d}"
                        )
                        for month in range(1, 13)
                    ]
                )
                # Do not cache an empty result, so a later run retries it.
                if not hourly.empty:
                    hourly.to_csv(cache_files[year])
        finally:
            ds.close()

    # Only the requested years: stale caches from other runs are ignored.
    cached = [path for path in cache_files.values() if os.path.exists(path)]
    if not cached:
        return None

    hourly = pd.concat(
        [pd.read_csv(path, index_col=0, parse_dates=True) for path in cached]
    ).sort_index()
    df = hourly.resample("1W").mean().dropna()
    df.to_csv(f"{region_name}_era5_full_master_weekly.csv")
    return df

# ==========================================
# LAI Weekly Download
# ==========================================
def _match_slice_order(coord_values, sl):
    """Return ``sl`` with start/stop swapped if they run against ``coord_values``' order."""
    if sl.start is None or sl.stop is None or len(coord_values) < 2:
        return sl
    ascending = bool(coord_values[-1] > coord_values[0])
    if (sl.start < sl.stop) != ascending:
        return slice(sl.stop, sl.start, sl.step)
    return sl


def download_lai_weekly(region_name, lat_slice, lon_slice, start_year, end_year):
    """Download MODIS leaf-area index for a region and average it weekly.

    The ``lai`` variable is cut to the lat/lon box and to
    ``{start_year}-01-01 .. {end_year}-12-31``, averaged over the box, and
    resampled to weekly means. The result is cached in
    ``{region_name}_lai_weekly_{start_year}_{end_year}.csv``. When that file
    already exists it is read back instead of downloading.

    Args:
        region_name: Prefix of the cache file.
        lat_slice: Label slice for the ``lat`` coordinate. Start and stop
            are swapped when they run opposite to the coordinate's storage
            order, so a north-to-south slice shared with ERA5 still selects
            data.
        lon_slice: Label slice for the ``lon`` coordinate (same reordering).
        start_year: First year to include.
        end_year: Last year to include (inclusive).

    Returns:
        DataFrame with a single ``LAI`` column indexed by ``time``, rows with
        missing values dropped. An empty result is returned but not cached,
        so a later run retries the download.
    """
    filename = f"{region_name}_lai_weekly_{start_year}_{end_year}.csv"
    if os.path.exists(filename):
        return pd.read_csv(filename, index_col=0, parse_dates=True)

    ds = xr.open_dataset(
        "https://nyu1.osn.mghpcc.org/leap-pangeo-pipeline/MODIS_LAI/MODIS_LAI.zarr",
        engine="zarr",
        chunks={},
    )
    try:
        weekly = (
            ds["lai"]
            .sel(
                lat=_match_slice_order(ds["lat"].values, lat_slice),
                lon=_match_slice_order(ds["lon"].values, lon_slice),
                time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
            )
            .mean(dim=["lat", "lon"])
            .resample(time="1W")
            .mean()
            .load()
        )
    finally:
        ds.close()

    # to_series() keeps only the data values, so any extra coordinates in
    # the store cannot turn into additional columns.
    df = weekly.to_series().rename("LAI").to_frame().dropna()
    df.index.name = "time"
    if not df.empty:
        df.to_csv(filename)

    return df

# ==========================================
# Main Execution
# ==========================================
if __name__ == "__main__":
    # Build the weekly ERA5 and LAI tables for the California box, then keep
    # only the weeks present in both.
    region = "california"
    df_era5 = download_era5_weekly(
        region, LAT_CALI, LON_CALI_ERA5, START_YEAR, END_YEAR
    )
    df_lai = download_lai_weekly(
        region, LAT_CALI, LON_CALI_LAI, START_YEAR, END_YEAR
    )

    if all(frame is not None for frame in (df_era5, df_lai)):
        # Drop any time-of-day so both weekly indexes align on the date.
        for frame in (df_era5, df_lai):
            frame.index = pd.to_datetime(frame.index).normalize()

        combined = df_era5.join(df_lai[["LAI"]], how="inner")
        df_master = combined.dropna().rename_axis("time")
        df_master.to_csv("master_data_for_all_models_weekly.csv")
