In [None]:
import earthaccess
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gc
import rioxarray # Required for swath-to-grid matching
from rasterio.enums import Resampling

# --- Earthdata login and study-area configuration ---
auth = earthaccess.login()
bounds = dict(min_lon=-120.75, max_lon=-119, min_lat=33.5, max_lat=34.5)
date_range = ("2024-04-01", "2024-12-31")

# --- regular 0.01-degree lat/lon grid that every product is mapped onto ---
target_lat = np.arange(bounds["min_lat"], bounds["max_lat"], 0.01)
target_lon = np.arange(bounds["min_lon"], bounds["max_lon"], 0.01)
target_grid = xr.Dataset(coords=dict(lat=target_lat, lon=target_lon))

# --- SST baseline (MUR L4 is already gridded, so plain interpolation is enough) ---
print("processing sst baseline...")
sst_query = earthaccess.search_data(
    short_name="MUR-JPL-L4-GLOB-v4.1",
    bounding_box=(
        bounds["min_lon"],
        bounds["min_lat"],
        bounds["max_lon"],
        bounds["max_lat"],
    ),
    temporal=date_range,
)
ds_sst = xr.open_mfdataset(
    earthaccess.open(sst_query),
    chunks={"time": 10},
    decode_timedelta=True,
)

# Apply the 273.15 offset, crop to the study area, interpolate onto the target
# grid, then average per calendar month (bins labelled by month start).
sst_offset = ds_sst.analysed_sst - 273.15
sst_region = sst_offset.sel(
    lat=slice(bounds["min_lat"], bounds["max_lat"]),
    lon=slice(bounds["min_lon"], bounds["max_lon"]),
)
sst_on_grid = sst_region.interp_like(target_grid, method="linear")
sst_monthly = sst_on_grid.resample(time="1MS").mean().compute()

# agg pace data
def process_pace_monthly(short_name, var_name, group_name):
    """Aggregate PACE L2 swath granules into monthly means on the target grid.

    Searches Earthdata for ``short_name`` granules intersecting the module-level
    ``bounds`` within ``date_range``, warps each swath onto the module-level
    ``target_lat`` / ``target_lon`` grid using its per-pixel navigation arrays,
    and keeps a running sum and valid-pixel count per calendar month.

    Parameters
    ----------
    short_name : str
        Earthdata collection short name passed to ``earthaccess.search_data``.
    var_name : str
        Name of the variable to read from ``group_name``.
    group_name : str
        NetCDF group that holds ``var_name``.

    Returns
    -------
    xarray.DataArray
        Array named ``var_name`` with dims ``(time, lat, lon)``, one time step
        per month (stamped at the first of the month), sorted by time. Cells
        that never received a valid pixel in a month are NaN.

    Raises
    ------
    ValueError
        If no granule reached the accumulation stage for any month.

    Notes
    -----
    Granules that fail to open or reproject are reported and skipped.
    Assumes ``target_lat`` / ``target_lon`` are ascending, uniformly spaced
    cell centres with at least two points each.
    """
    # rasterio is already a dependency of this file (see the Resampling import).
    from rasterio.transform import from_origin

    print(f"searching for {short_name}...")
    query = earthaccess.search_data(
        short_name=short_name,
        bounding_box=(bounds['min_lon'], bounds['min_lat'], bounds['max_lon'], bounds['max_lat']),
        temporal=date_range
    )
    files = earthaccess.open(query)

    grid_shape = (len(target_lat), len(target_lon))
    lat_step = float(target_lat[1] - target_lat[0])
    lon_step = float(target_lon[1] - target_lon[0])
    # Pin the destination raster to the target grid. Without an explicit
    # transform, rioxarray derives the output extent from each swath's own
    # bounds, so granules would land on different grids and could not be summed.
    # Grid values are treated as cell centres, hence the half-cell offset to
    # reach the upper-left corner of the north-up raster.
    dst_transform = from_origin(
        float(target_lon[0]) - lon_step / 2.0,
        float(target_lat[-1]) + lat_step / 2.0,
        lon_step,
        lat_step,
    )

    # month key ('YYYY-MM') -> [running sum, running count of valid pixels]
    storage = {}

    for f in files:
        try:
            # Context managers guarantee the groups are closed even when a
            # granule fails part-way through.
            with xr.open_dataset(f, group="navigation_data", engine="h5netcdf") as nav, \
                    xr.open_dataset(f, group=group_name, engine="h5netcdf") as data:

                with xr.open_dataset(f, engine="h5netcdf") as meta:
                    t = pd.to_datetime(meta.attrs['time_coverage_start'])
                    m_key = t.strftime('%Y-%m')

                if m_key not in storage:
                    storage[m_key] = [np.zeros(grid_shape), np.zeros(grid_shape)]

                swath_lat = nav.latitude.values
                swath_lon = nav.longitude.values

                # Attach the per-pixel geolocation and tell rioxarray which
                # dims are the raster x/y axes.
                subset = data[var_name].assign_coords({
                    "lat": (("number_of_lines", "pixels_per_line"), swath_lat),
                    "lon": (("number_of_lines", "pixels_per_line"), swath_lon)
                })
                subset = subset.rio.set_spatial_dims("pixels_per_line", "number_of_lines")
                subset = subset.rio.write_crs("epsg:4326")

                # Warp the swath onto the fixed target raster. The keyword is
                # `resampling`; `resample` is not a rio.reproject parameter.
                regrid = subset.rio.reproject(
                    dst_crs="epsg:4326",
                    shape=grid_shape,
                    transform=dst_transform,
                    src_geoloc_array=(swath_lon, swath_lat),
                    resampling=Resampling.nearest,
                    nodata=np.nan
                )

                # The north-up raster has row 0 at the northern edge; flip it
                # so rows follow the ascending order of target_lat.
                values = np.asarray(regrid.values, dtype=float)[::-1, :]

            # Count only finite samples so NaN/inf never leak into the sum.
            valid_mask = np.isfinite(values)
            storage[m_key][0] += np.where(valid_mask, values, 0.0)
            storage[m_key][1] += valid_mask.astype(float)

            # Swaths are large; release them before opening the next granule.
            del subset, regrid, values, valid_mask, swath_lat, swath_lon
            gc.collect()

        except Exception as e:
            # Best effort: one unreadable granule must not abort the whole run.
            print(f"skipping granule: {e}")

    monthly_arrays = []
    for m_key, (r_sum, r_count) in storage.items():
        # Mean where at least one valid pixel was seen, NaN elsewhere.
        avg = np.divide(r_sum, r_count, out=np.full_like(r_sum, np.nan), where=r_count > 0)
        da = xr.DataArray(avg, coords=[('lat', target_lat), ('lon', target_lon)], name=var_name)
        da = da.expand_dims(time=[pd.to_datetime(m_key)])
        monthly_arrays.append(da)

    if not monthly_arrays:
        raise ValueError(f"no data successfully processed for {short_name}")

    return xr.concat(monthly_arrays, dim='time').sortby('time')

# Run the monthly aggregation for each PACE product
pace_chl_monthly = process_pace_monthly(
    short_name="PACE_OCI_L2_BGC",
    var_name="chlor_a",
    group_name="geophysical_data",
)
pace_bbp_monthly = process_pace_monthly(
    short_name="PACE_OCI_L2_IOP",
    var_name="bbp_442",
    group_name="geophysical_data",
)

# Combine the three monthly fields into one dataset and write it to disk
analysis_ds = xr.Dataset(
    dict(
        chlor_a=pace_chl_monthly,
        bbp_442=pace_bbp_monthly,
        sst=sst_monthly,
    )
)
analysis_ds.to_netcdf("l2_sst_correlation.nc")

# Visualization Logic (Remains the same as your draft)
print("Plotting results...")
# ... [Rest of your plotting code] ...

processing sst baseline...


QUEUEING TASKS | :   0%|          | 0/276 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/276 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/276 [00:00<?, ?it/s]