# Imports and functions

In [None]:
import xarray as xr
import warnings
warnings.filterwarnings("ignore", category=xr.SerializationWarning)
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Make the directory one level above this notebook importable, so the
# project-local `config` and `utils` packages imported below can be found.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
from config import FLUX_DATA_PATH, FLUX_METADATA, MICASA_PREPROCESSED_DATA, MERRA_DATA_PATH

In [None]:
from utils.functions import import_flux_metadata, import_site_RMSE_data

In [None]:
from utils.plotting import polyfit1d_and_plot

# Open and combine virtual datasets

In [None]:
# Virtualized dataset path
# Directory holding the virtualized (reference) stores for the MERRA-2 data.
ref_path = os.path.join(MERRA_DATA_PATH, "virtual_store")

In [None]:
# fsspec "reference::" URLs pointing at the two parquet reference stores.
# Single quotes are used inside the f-string because reusing the outer quote
# character inside a replacement field is only legal from Python 3.12 (PEP 701).
ref_url1 = f"reference::{os.path.join(ref_path, 'vstore1.parquet')}"
ref_url2 = f"reference::{os.path.join(ref_path, 'vstore2.parquet')}"

In [None]:
# Open the first virtual store lazily through the zarr engine (no consolidated metadata).
ds1 = xr.open_dataset(ref_url1, consolidated=False, engine="zarr")

In [None]:
# Open the second virtual store the same way as the first.
ds2 = xr.open_dataset(ref_url2, consolidated=False, engine="zarr")

In [None]:
# Inspect the T2M variable of the second store.
ds2["T2M"]

In [None]:
# Stitch T2M from both stores into a single time series (a DataArray, despite the name).
ds_combined = xr.concat([store["T2M"] for store in (ds1, ds2)], dim="time")

In [None]:
# Replace missing values with NaN
# Replace missing values with NaN. 999999986991104 is 1e15 rounded to float32;
# presumably this is the store's fill value — confirm against the encoding.
ds_combined = ds_combined.where(
    ds_combined != 999999986991104,
    other=np.nan,
)

In [None]:
# ds_combined

# Import RMSE values

In [None]:
# Load per-site RMSE results for the ANN model, joined with the flux-site metadata.
df_ANN = import_site_RMSE_data(FLUX_METADATA, '../analysis/RMSE_results_ANN.csv')

In [None]:
# Preview the loaded RMSE table.
df_ANN.head()

# Total 30-year averages

In [None]:
# Mean T2M over the full time axis at every grid cell.
ds_totavg = ds_combined.mean(dim="time")

In [None]:
# Inspect the time-averaged field.
ds_totavg

### Extract the nearest grid-cell value for each site from the xarray dataset

In [None]:
# TODO: consider moving this helper into utils so preprocessing can reuse it.
def nearest_da(df, da):
    """
    Sample ``da`` at the grid cell nearest to every (lat, lon) row of ``df``.

    Uses xarray vectorized (pointwise) indexing: the ``lat`` and ``lon``
    columns are wrapped as DataArrays that share a new ``points`` dimension,
    so a single ``sel`` call returns one value per row rather than the
    lat x lon outer product.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``lat`` and ``lon`` columns.
    da : xarray.DataArray
        Array indexed by ``lat`` and ``lon`` coordinates.

    Returns
    -------
    xarray.DataArray
        ``da`` with ``lat``/``lon`` replaced by a ``points`` dimension, one
        entry per row of ``df`` in the same order.
    """
    indexers = {
        coord: xr.DataArray(df[coord].values, dims='points')
        for coord in ('lat', 'lon')
    }
    return da.sel(indexers, method='nearest')

In [None]:
# Preview the nearest-grid-cell time-averaged T2M for each site.
nearest_da(df_ANN, ds_totavg)

In [None]:
# New table (df_ANN is left untouched) with each site's full-period mean T2M.
df_ANN_tot = df_ANN.assign(T2M_avg=nearest_da(df_ANN, ds_totavg).values)

In [None]:
# Preview the table with the added T2M_avg column.
df_ANN_tot.head()

In [None]:
# x-axis label for the full-period plot.
# NOTE(review): 1991-2021 inclusive spans 31 calendar years; confirm the label
# matches the actual time range covered by the two stores.
xlabel = '2-meter Temperature (K)\n30-Yr Average (1991-2021)'

In [None]:
# Plot site NEE RMSE against the full-period mean T2M (presumably with a 1-D
# polynomial fit, per the helper's name). The trailing ';' suppresses the
# returned value in the notebook output.
polyfit1d_and_plot(df_ANN_tot, "T2M_avg", "NEE_RMSE", xlabel, "NEE RMSE vs MERRA2 T2M");

# Selective-Year Averages
### (For each site, average only over the years listed in its "Years of AmeriFlux FLUXNET Data" column)

In [None]:
# Shorten the years-of-data column name for easier access.
df_ANN = df_ANN.rename({'Years of AmeriFlux FLUXNET Data': "YEARS_DATA"}, axis="columns")

In [None]:
def parse_years(year_string):
    """
    Parse one site's years-of-data field into a list of ints.

    Parameters
    ----------
    year_string : str, int, float, list/tuple/ndarray of years, or NA
        Usually a comma-separated string such as ``"2015, 2016, 2017"``.
        A lone year that pandas read as a number (e.g. ``2015`` or ``2015.0``)
        and an already-parsed sequence are also accepted, so re-running the
        cell that applies this function does not crash.

    Returns
    -------
    list of int
        The years in their original order, or ``[]`` for a missing (NA)
        value. Empty tokens (e.g. from a trailing comma) are skipped.
    """
    # Already parsed (e.g. the apply cell was re-run). Check this first:
    # pd.isna() on a multi-element list returns an array, which is ambiguous in `if`.
    if isinstance(year_string, (list, tuple, np.ndarray)):
        return [int(year) for year in year_string]
    if pd.isna(year_string):
        return []
    # pandas reads a column that holds only single years as numeric, not str.
    if isinstance(year_string, (int, float, np.integer, np.floating)):
        return [int(year_string)]
    tokens = (token.strip() for token in str(year_string).split(','))
    return [int(token) for token in tokens if token]

In [None]:
# Turn each site's years field into a list of integer years.
df_ANN["YEARS_DATA"] = df_ANN["YEARS_DATA"].map(parse_years)

In [None]:
# Check that YEARS_DATA now holds lists of ints.
df_ANN.head()

In [None]:
# Calendar-year mean T2M at every grid cell (time is replaced by a "year" dimension).
ds_annavgs = ds_combined.groupby("time.year").mean(dim="time")

In [None]:
# Annual-mean T2M at each site's nearest grid cell; expected dims are
# ("year", "points") — TODO confirm with the display below.
test_ds = nearest_da(df_ANN, ds_annavgs)

In [None]:
# Inspect the per-site annual means.
test_ds

### Create a mask that selects only each site's years of interest

In [None]:
# All-False boolean mask shaped like test_ds; only the shape is needed,
# so there is no need to convert the DataArray's values.
mask = np.zeros(test_ds.shape, dtype=bool)
mask.shape

In [None]:
# Build a ("year", "points") mask that is True where the year is one of that
# site's data years. The mask is labelled, so `where` aligns by dimension name
# instead of relying on test_ds's axis order.
years = test_ds["year"].values
mask_values = np.zeros((years.size, len(df_ANN)), dtype=bool)
for i, yrs in enumerate(df_ANN["YEARS_DATA"]):
    mask_values[:, i] = np.isin(years, yrs)
mask = xr.DataArray(mask_values, dims=("year", "points"), coords={"year": years})

# Average each site only over its selected years. Sites with no matching years become NaN.
means = test_ds.where(mask).mean('year')

In [None]:
# New table with each site's T2M averaged only over its own data years.
df_ANN_sub = df_ANN.assign(T2M_avg=means.values)

In [None]:
# Preview the selective-year table.
df_ANN_sub.head()

In [None]:
# x-axis label for the selective-year plot.
xlabel = '2-meter Temperature (K)\nAveraged Across Years of Available Data'

In [None]:
# Plot site NEE RMSE against the selective-year mean T2M (presumably with a
# 1-D polynomial fit, per the helper's name). The trailing ';' suppresses the
# returned value in the notebook output.
polyfit1d_and_plot(df_ANN_sub, "T2M_avg", "NEE_RMSE", xlabel, "NEE RMSE vs MERRA2 T2M");