In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from cartopy import crs as ccrs 
import cartopy.feature as cfeature
# import hvplot.xarray

In [None]:
# Memory check: report the host's total/used/free RAM (human-readable units)
# before the large LPJ and MiCASA datasets are opened below.
!free -h

In [None]:
def test_align_exact(ds1, ds2):
    """
    Test two datasets are exactly aligned
    """
    try:
        ds1_aligned, ds2_aligned = xr.align(ds1, ds2, join="exact")
        print("Aligned.")
        return ds1_aligned, ds2_aligned

    except ValueError as e:
        print(e)
        return None

## LPJ

### Daily NPP

In [None]:
# Directory holding the LPJ NetCDF outputs (S2_RESP_ACCLIM run, ERA5 forcing).
lpj_path: str = "/discover/nobackup/projects/GHGC/LPJ_collaborations/NRT_carbon_budget_const_lu/20250521/S2_RESP_ACCLIM/ncdf_outputs"

In [None]:
# Path to the LPJ daily NPP file inside lpj_path.
lpj_dnpp: str = f"{lpj_path}/ERA5_S2_RESP_ACCLIM_dnpp.nc"

In [None]:
# Open LPJ daily NPP. Dask chunking (chunks="auto") is intentionally left off here.
ds_dnpp = xr.open_dataset(filename_or_obj=lpj_dnpp)

### Daily Rh

In [None]:
# Path to the LPJ daily heterotrophic respiration (Rh) file inside lpj_path.
lpj_drh: str = f"{lpj_path}/ERA5_S2_RESP_ACCLIM_drh.nc"

In [None]:
# Open LPJ daily Rh and display its structure.
ds_drh = xr.open_dataset(filename_or_obj=lpj_drh)
ds_drh

### Combine daily Rh and NPP (inputs for NEE = Rh − NPP)

In [None]:
# Confirm Rh and NPP share identical indexes before merging them.
# test_align_exact returns None when the inputs are not aligned, so check
# explicitly rather than letting tuple-unpacking fail with a TypeError.
lpj_alignment = test_align_exact(ds_drh, ds_dnpp)
if lpj_alignment is None:
    raise ValueError(
        "LPJ Rh and NPP datasets are not exactly aligned; "
        "inspect their coordinates before merging."
    )
ds_aligned1, ds_aligned2 = lpj_alignment

In [None]:
# Merge Rh and NPP into one dataset. join="exact" makes the merge fail loudly
# if the indexes ever stop matching, instead of the default outer join
# silently unioning coordinates and padding with NaN.
ds_lpj_combined = xr.merge([ds_drh, ds_dnpp], join="exact")

In [None]:
# Inspect the merged LPJ dataset (Rh + NPP).
ds_lpj_combined

## MiCASA

In [None]:
# Virtualized MiCASA reference store (Parquet references, read through the zarr engine below).
micasa_path: str = "micasa_virtualized/vstore.parquet"

In [None]:
# Open MiCASA through its reference store ("reference::" prefix) with the
# zarr engine. There is no consolidated metadata, so consolidated=False.
micasa_open_kwargs = dict(engine="zarr", consolidated=False)
ds_mi = xr.open_dataset(f"reference::{micasa_path}", **micasa_open_kwargs)
ds_mi

In [None]:
# Inspect MiCASA's NPP variable (dims, coords, attrs).
ds_mi["NPP"]

In [None]:
# Rechunk MiCASA into dask blocks of 30 days over the full lat/lon grid.
ds_mi_chunk = ds_mi.chunk(time=30, lat=900, lon=1800)
ds_mi_chunk

## Align MiCASA and LPJ

In [None]:
# Micasa starts at Jan 2001
ds_lpj_sel = ds_lpj_combined.sel(time=slice("2001", None))
ds_lpj_sel

##### LPJ has 6 fewer days than MiCASA over this period: LPJ does not include leap days (Feb 29)

In [None]:
# Downsample to match LPJ
ds_mi_downsample = ds_mi_chunk.coarsen(lat=5,lon=5, boundary="trim").mean() # Downsampling (5x5 aggregation since 0.5°/0.1° = 5)
ds_mi_downsample

In [None]:
# Alignment check after coarsening. Expected to fail: MiCASA still contains
# leap days that LPJ lacks (dropped in the cells below).
test_align_exact(ds1=ds_lpj_sel, ds2=ds_mi_downsample)

In [None]:
# Chunking
chunk_config = {'time': 365, 'lat': 900, 'lon': 1800}

In [None]:
# Drop leap days to match LPJ and drop unneeded vars
leap_day_mask = ~((ds_mi_downsample.time.dt.month == 2) & (ds_mi_downsample.time.dt.day == 29))
ds_mi_noleap = ds_mi_downsample.sel(time=leap_day_mask)
# Rechunk after remove leap days
ds_mi_sel = ds_mi_noleap.chunk(chunk_config)[["NEE", "NPP", "Rh"]]
ds_mi_sel

In [None]:
# # Change var names for consistency with micasa
ds_lpj_sel = ds_lpj_sel.rename_dims({"latitude": "lat", "longitude": "lon"})
ds_lpj_sel = ds_lpj_sel.rename_vars({"latitude": "lat", "longitude": "lon"})

# print(ds_lpj_sel.lon.values[:5], ds_lpj_sel.lon.values[-5:])
# print(ds_mi_sel.lon.values[:5], ds_mi_sel.lon.values[-5:])

test_align_exact(ds_lpj_sel, ds_mi_sel) # After renaming they aren't aligned?? 
# # Maybe because before they were diff dimensions so they didn't show up as not aligned

In [None]:
# Are LPJ's and MiCASA's lat/lon coordinate values exactly identical?
for coord_name in ("lat", "lon"):
    print(np.array_equal(ds_lpj_sel[coord_name].values, ds_mi_sel[coord_name].values))

In [None]:
# Show the first few latitudes of each product and their element types,
# to see whether values or dtypes differ.
mi_lat = ds_mi_sel.lat.values
lpj_lat = ds_lpj_sel.lat.values
print(mi_lat[:5], lpj_lat[:5])
type(mi_lat[0]), type(lpj_lat[0])

In [None]:
# Disabled experiment: cast LPJ lat/lon to MiCASA's coordinate dtype to test
# whether a dtype mismatch caused the failed exact alignment. The follow-up
# cells note that this did not make the coordinates match. Kept for reference.
# ds_lpj_fix = ds_lpj_sel.copy()
# ds_lpj_fix.coords['lat'] = ds_lpj_sel.lat.astype(ds_mi_sel.lat.dtype)
# ds_lpj_fix.coords['lon'] = ds_lpj_sel.lon.astype(ds_mi_sel.lon.dtype)

In [None]:
# Disabled: check whether the dtype-cast latitudes equal MiCASA's.
# print(np.array_equal(ds_lpj_fix.lat.values, ds_mi_sel.lat.values))

In [None]:
# Disabled: exact-alignment check on the dtype-cast LPJ copy.
# test_align_exact(ds_lpj_fix, ds_mi_sel) # Still not aligned

In [None]:
# Disabled: element-wise longitude comparison after the dtype cast.
# test = (ds_mi_sel.lon.values == ds_lpj_fix.lon.values)
# test[:5] # Still not matching

In [None]:
# The coarsened MiCASA lat/lon values do not exactly equal LPJ's (see checks
# above), so exact alignment fails. Snap MiCASA onto the LPJ grid with
# nearest-neighbour reindexing.
# NOTE(review): no `tolerance` is passed, so any grid offset is silently
# absorbed. Confirm the coordinate differences printed above are only
# floating-point noise before trusting the comparison.
ds_mi_sel = ds_mi_sel.reindex(lat=ds_lpj_sel.lat, lon=ds_lpj_sel.lon, method='nearest')

# test_align_exact returns None on failure; fail with a clear message rather
# than a TypeError from tuple-unpacking None.
lpj_mi_alignment = test_align_exact(ds_lpj_sel, ds_mi_sel)
if lpj_mi_alignment is None:
    raise ValueError("LPJ and MiCASA are still not exactly aligned after reindexing.")
ds_lpj_align, ds_mi_align = lpj_mi_alignment

In [None]:
# Chunk the aligned LPJ data with the shared chunk_config.
ds_lpj_align = ds_lpj_align.chunk(**chunk_config)
ds_lpj_align

## Comparisons

In [None]:
# Compare the NPP variables' metadata (e.g. units) before comparing values.
lpj_npp_attrs = ds_lpj_align["dnpp"].attrs
mi_npp_attrs = ds_mi_align["NPP"].attrs
print(f"{lpj_npp_attrs}\n\n{mi_npp_attrs}")

### Means

In [None]:
# Seasonal mean NPP for a single test year. For the final plots widen the
# period, e.g. time=slice("2022", "2024").
# LPJ dnpp is divided by the seconds in a day to put it on MiCASA's per-second
# basis. NOTE(review): assumes dnpp is a per-day quantity; confirm against the
# attrs printed above.
# Caveat: grouping by season pools all months of that season in the selected
# period, so "DJF" here is Jan+Feb+Dec of the same year, not a contiguous winter.
SECONDS_PER_DAY = 86400
analysis_year = "2024"

ds_lpj_sub = ds_lpj_align["dnpp"].sel(time=analysis_year) / SECONDS_PER_DAY
ds_lpj_means = ds_lpj_sub.groupby("time.season").mean(dim="time")

ds_mi_sub = ds_mi_align["NPP"].sel(time=analysis_year)
ds_mi_means = ds_mi_sub.groupby("time.season").mean(dim="time")

In [None]:
# Seasonal-mean NPP difference (LPJ minus MiCASA), cast to float64.
means = (ds_lpj_means - ds_mi_means).astype("float64")
means

In [None]:
%%time 
# Materialise the lazy (chunked) difference into memory and time the step.
# Computing the whole dataset takes the same time as slicing by lat/lon first.
means = means.compute()

In [None]:
# Keep only grid cells with data in at least one season. With drop=True,
# lat/lon labels that are all-NaN (presumably ocean) are removed entirely.
# Author's note: this did not resolve the earlier notnull().any issue.
mask = means.notnull().any(dim="season")
means = means.where(mask, drop=True)
means

#### Plot mean

In [None]:
%matplotlib inline
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.PlateCarree())
plot = means.sel(season='MAM').plot.pcolormesh(ax=ax, 
                                                        transform=ccrs.PlateCarree(),
                                                        cmap="RdBu",
                                                       add_colorbar=False,)
cb = plt.colorbar(plot,orientation='horizontal', shrink=0.8)

In [None]:
# Per-season min/max of the mean-NPP difference, used to choose shared colour
# limits for the panel plots. Uses season_min/season_max instead of min/max so
# the built-in min() and max() are not shadowed for the rest of the session.
for season in means.season.values:
    season_diff = means.sel(season=season)
    season_min = season_diff.min().item()
    season_max = season_diff.max().item()
    print(f"{season}: {season_min:.2e}, {season_max:.2e}")

In [None]:
# Disabled: overall variance of the mean-difference field.
# means.var().values

In [None]:
%matplotlib inline
proj = ccrs.PlateCarree()
fig, axs = plt.subplots(2,2, figsize=(12, 8), subplot_kw= {'projection': proj});
for ax, season in zip(axs.flat, means.season.values):
    # print(i, season)
    plot = means.sel(season=season).plot.pcolormesh(ax=ax, 
                                                        transform=ccrs.PlateCarree(),
                                                             vmin=-9e-8,vmax=9e-8,
                                                             
                                                             cmap="RdBu",
                                                       add_colorbar=False,)
    cb = plt.colorbar(plot,orientation='horizontal', shrink=0.8, pad=0.05, 
                      # label="LPJ - MiCASA\nMean NPP difference (kg C m-2 s-1)\n2024 avg"
                              label="Difference of mean NPP (kg C m-2 s-1)\nLPJ-EOSIM — MiCASA",
                     )

### Test output plots

In [None]:
# One full-size figure per season for output.
# SAVE_FIGURES = False: preview only the first season (quick look).
# SAVE_FIGURES = True:  render every season and write NPPDiff_<season>.png
#                       into output_dir.
import os

SAVE_FIGURES = False
output_dir = "tests"

for season in means.season.values:
    fig, ax = plt.subplots(1, 1, figsize=(16, 8), subplot_kw={'projection': proj})
    plot = means.sel(season=season).plot.pcolormesh(
        ax=ax,
        transform=ccrs.PlateCarree(),
        vmin=-9e-8, vmax=9e-8,
        cmap="RdBu",
        add_colorbar=False,
    )
    ax.set_title(f"{season}")
    cb = plt.colorbar(
        plot,
        orientation='horizontal',
        shrink=0.8,
        pad=0.05,
        label="LPJ-EOSIM — MiCASA (kg C m$^{-2}$ s$^{-1}$)\n2024",
    )
    fig.suptitle("Difference of Mean NPP", x=0.5, y=.92, fontsize=15)
    if SAVE_FIGURES:
        # Save before showing: the inline backend may close figures on show.
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, f"NPPDiff_{season}.png"))
    plt.show()
    if not SAVE_FIGURES:
        break  # preview mode: first season only

### Variance

In [None]:
# Seasonal variance of daily NPP for each product, then LPJ minus MiCASA.
# The result has squared units: (kg C m-2 s-1)^2.
# var() over all-NaN (presumably ocean) cells may warn about NaN values; the
# result is still usable, and those cells are trimmed below.
ds_lpj_var = ds_lpj_sub.groupby(ds_lpj_sub.time.dt.season).var(dim="time")
ds_mi_var = ds_mi_sub.groupby(ds_mi_sub.time.dt.season).var(dim="time")

variance = (ds_lpj_var - ds_mi_var).astype("float64")

# Compute eagerly so all-NaN lat/lon rows and columns can be dropped.
variance = variance.compute()
mask = variance.notnull().any(dim=["season"])
variance = variance.where(mask, drop=True)

In [None]:
# Per-season min/max of the NPP-variance difference, for choosing colour
# limits. Uses season_min/season_max instead of min/max so the built-in
# min() and max() are not shadowed for the rest of the session.
for season in variance.season.values:
    season_diff = variance.sel(season=season)
    season_min = season_diff.min().item()
    season_max = season_diff.max().item()
    print(f"{season}: {season_min:.2e}, {season_max:.2e}")

In [None]:
# Quick look at the JJA difference in NPP variance (LPJ minus MiCASA), with a
# title and squared units so the figure is interpretable on its own.
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={"projection": ccrs.PlateCarree()})
plot = variance.sel(season='JJA').plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cmap="RdBu",
    add_colorbar=False,
)
ax.set_title("JJA NPP variance difference (LPJ-EOSIM — MiCASA)")
cb = fig.colorbar(
    plot,
    ax=ax,
    orientation='horizontal',
    shrink=0.8,
    label="(kg C m$^{-2}$ s$^{-1}$)$^2$",
)

In [None]:
%matplotlib inline
proj = ccrs.PlateCarree()
fig, axs = plt.subplots(2,2, figsize=(12, 8), subplot_kw= {'projection': proj});
for ax, season in zip(axs.flat, variance.season.values):
    # print(i, season)
    plot = variance.sel(season=season).plot.pcolormesh(ax=ax, 
                                                        transform=ccrs.PlateCarree(),
                                                             # vmin=-9e-8,vmax=9e-8,
                                                             
                                                             cmap="RdBu",
                                                       add_colorbar=False,)
    cb = plt.colorbar(plot,orientation='horizontal', shrink=0.8, pad=0.05, 
                      # label="LPJ - MiCASA\nMean NPP difference (kg C m-2 s-1)\n2024 avg"
                              label="Difference of mean NPP\n(kg C m-2 s-1)\nLPJ-EOSIM — MiCASA",
                     )