In [None]:
import os

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib import colors
from cartopy import crs as ccrs
import cartopy.feature as cfeature
import hvplot.xarray  # registers the .hvplot accessor used on xarray objects below


In [None]:
# Memory check: show total/used/free system memory in human-readable units
# (a later cell notes that computing NEE on the full arrays crashes the node).
!free -h

In [None]:
def test_align_exact(ds1, ds2):
    """
    Test two datasets are exactly aligned
    """
    try:
        ds1_aligned, ds2_aligned = xr.align(ds1, ds2, join="exact")
        print("Aligned.")
        return ds1_aligned, ds2_aligned

    except ValueError as e:
        print(e)
        return None

## LPJ

### Daily NPP

In [None]:
# Directory holding the LPJ NetCDF outputs for this run. Set the LPJ_PATH environment
# variable to point at a different run or machine without editing the notebook;
# the default is the original Discover path.
lpj_path = os.environ.get(
    "LPJ_PATH",
    "/discover/nobackup/projects/GHGC/LPJ_collaborations/NRT_carbon_budget_const_lu/20250521/S2_RESP_ACCLIM/ncdf_outputs",
)

In [None]:
# Daily NPP file inside the LPJ output directory.
lpj_dnpp = "{}/ERA5_S2_RESP_ACCLIM_dnpp.nc".format(lpj_path)

In [None]:
# Open the LPJ daily NPP output. No dask chunks are requested here (chunks="auto"
# is left disabled); the aligned LPJ data are rechunked further down.
ds_dnpp = xr.open_dataset(filename_or_obj=lpj_dnpp)

### Daily Rh

In [None]:
# Daily Rh file inside the LPJ output directory.
lpj_drh = "{}/ERA5_S2_RESP_ACCLIM_drh.nc".format(lpj_path)

In [None]:
# Open the LPJ daily Rh output and show its structure.
ds_drh = xr.open_dataset(filename_or_obj=lpj_drh)
ds_drh

### Combine to calculate NEE ($\mathrm{NEE} = R_h - \mathrm{NPP}$)

In [None]:
# Check that the Rh and NPP files have identical coordinates before merging them.
# test_align_exact returns None on a mismatch, so assert first: a misalignment then
# stops here with a clear message instead of a tuple-unpacking TypeError.
lpj_alignment = test_align_exact(ds_drh, ds_dnpp)
assert lpj_alignment is not None, "LPJ Rh and NPP datasets are not exactly aligned"
ds_aligned1, ds_aligned2 = lpj_alignment

In [None]:
# join="exact" makes the merge raise if the two files' coordinates differ,
# instead of the default outer join silently padding any mismatch with NaN.
ds_lpj_combined = xr.merge([ds_drh, ds_dnpp], join="exact")

In [None]:

# Display the merged LPJ dataset holding the variables from both the Rh and NPP files.
ds_lpj_combined

In [None]:
# Computing NEE this way crashes the node (presumably out of memory), so the data
# need to be chunked first. Kept disabled for now.
# NOTE(review): the merged dataset is named `ds_lpj_combined`, not `ds_combined`.
# ds_lpj_combined['dnee'] = ds_lpj_combined['drh'] - ds_lpj_combined['dnpp']

## MiCASA

In [None]:
# Relative path to the MiCASA reference store, opened below through the "reference::"
# protocol (presumably a virtualizarr/kerchunk Parquet reference file — confirm).
micasa_path = "micasa_virtualized/vstore.parquet"

In [None]:
# Open MiCASA from the reference store using the fsspec "reference::" URL prefix and
# the zarr engine.
# consolidated=False: presumably the store has no consolidated zarr metadata — confirm.
ds_mi = xr.open_dataset(f"reference::{micasa_path}", 
                        engine="zarr",
                        consolidated=False,
                        )
ds_mi

In [None]:
# Chunk MiCASA for dask: 30 time steps per chunk and 900 x 1800 lat/lon cells.
ds_mi_chunk = ds_mi.chunk(dict(time=30, lat=900, lon=1800))
ds_mi_chunk

## Align MiCASA and LPJ

In [None]:
# MiCASA starts in January 2001, so keep LPJ from 2001 onward (open-ended slice).
ds_lpj_sel = ds_lpj_combined.sel({"time": slice("2001", None)})
ds_lpj_sel

##### Six days are missing: LPJ appears not to include leap days (29 February), which would account for the six leap days between 2001 and 2024

In [None]:
# Downsample MiCASA to the LPJ grid: 0.5° / 0.1° = 5, so average 5 x 5 blocks of cells.
# boundary="trim" drops leftover rows/columns that do not fill a whole block.
# Note: coarsen() also averages the lat/lon coordinates by default, so the new
# coordinates can differ slightly (floating point) from LPJ's.
ds_mi_downsample = ds_mi_chunk.coarsen({"lat": 5, "lon": 5}, boundary="trim").mean()
ds_mi_downsample

In [None]:
# Exact-alignment check. Only dimensions shared by both objects are compared, and at
# this point LPJ still uses latitude/longitude while MiCASA uses lat/lon, so in effect
# only `time` is checked — and MiCASA's extra leap days make it fail.
test_align_exact(ds_lpj_sel, ds_mi_downsample) # Leap days still not aligned

In [None]:
# Shared dask chunk layout for the aligned datasets: 365 time steps (one LPJ year,
# which has no leap days) by 900 x 1800 lat/lon cells.
chunk_config = dict(time=365, lat=900, lon=1800)

In [None]:
# LPJ has no 29 February, so drop those days from MiCASA. Then keep only the variables
# being compared and rechunk, because removing days leaves uneven time chunks.
is_feb29 = (ds_mi_downsample.time.dt.month == 2) & (ds_mi_downsample.time.dt.day == 29)
leap_day_mask = ~is_feb29
ds_mi_noleap = ds_mi_downsample.sel(time=leap_day_mask)
ds_mi_sel = ds_mi_noleap[["NEE", "NPP", "Rh"]].chunk(chunk_config)
ds_mi_sel

In [None]:
# Rename LPJ's spatial coordinates to MiCASA's names. A single .rename() renames each
# dimension together with its coordinate variable (rather than rename_dims followed by
# rename_vars). Only names still present are renamed, so re-running the cell is a no-op
# instead of a "does not exist" error.
lpj_to_micasa_names = {"latitude": "lat", "longitude": "lon"}
ds_lpj_sel = ds_lpj_sel.rename(
    {
        old: new
        for old, new in lpj_to_micasa_names.items()
        if old in ds_lpj_sel.variables or old in ds_lpj_sel.dims
    }
)

# Now that lat/lon are shared dimensions, the exact-alignment check compares them too
# (before the rename only `time` was shared, so a lat/lon mismatch could not show up).
# NOTE(review): coarsen() averaged MiCASA's coordinates, so tiny floating-point
# differences from the LPJ grid may be why this fails — inspected in the next cells.
test_align_exact(ds_lpj_sel, ds_mi_sel)

In [None]:
# Element-wise exact comparison of each spatial coordinate between the two datasets.
for coord_name in ("lat", "lon"):
    print(np.array_equal(ds_lpj_sel[coord_name].values, ds_mi_sel[coord_name].values))

In [None]:
# Print the first few latitude values from each dataset and show their element types,
# to see why the exact comparison fails (e.g. a float32 vs float64 mismatch).
print(ds_mi_sel.lat.values[:5], ds_lpj_sel.lat.values[:5])
type(ds_mi_sel.lat.values[0]), type(ds_lpj_sel.lat.values[0])

In [None]:
# ds_lpj_fix = ds_lpj_sel.copy()
# ds_lpj_fix.coords['lat'] = ds_lpj_sel.lat.astype(ds_mi_sel.lat.dtype)
# ds_lpj_fix.coords['lon'] = ds_lpj_sel.lon.astype(ds_mi_sel.lon.dtype)

In [None]:
# print(np.array_equal(ds_lpj_fix.lat.values, ds_mi_sel.lat.values))

In [None]:
# test_align_exact(ds_lpj_fix, ds_mi_sel) # Still not aligned

In [None]:
# test = (ds_mi_sel.lon.values == ds_lpj_fix.lon.values)
# test[:5] # Still not matching

In [None]:
# The coarsened MiCASA lat/lon do not exactly equal LPJ's (presumably floating-point
# noise from coarsen() averaging the coordinates), so snap MiCASA onto the LPJ grid.
# `tolerance` limits how far a label may be snapped: any LPJ point farther than a small
# fraction of the 0.5° spacing from a MiCASA point becomes NaN, rather than silently
# taking a neighbouring cell's value.
grid_tolerance_deg = 0.01
ds_mi_sel = ds_mi_sel.reindex(
    lat=ds_lpj_sel.lat,
    lon=ds_lpj_sel.lon,
    method="nearest",
    tolerance=grid_tolerance_deg,
)

# Stop with a clear message (instead of a tuple-unpacking TypeError) if anything,
# e.g. the time axis, is still not aligned.
lpj_mi_alignment = test_align_exact(ds_lpj_sel, ds_mi_sel)
assert lpj_mi_alignment is not None, "LPJ and MiCASA are still not exactly aligned"
ds_lpj_align, ds_mi_align = lpj_mi_alignment

In [None]:
# Give the aligned LPJ data the shared dask chunk layout from chunk_config.
ds_lpj_align = ds_lpj_align.chunk(chunks=chunk_config)
ds_lpj_align

## Start comparisons

In [None]:
# Compare the metadata (e.g. units) of the LPJ and MiCASA NPP variables, separated by a blank line.
print(f"{ds_lpj_align['dnpp'].attrs}\n\n{ds_mi_align['NPP'].attrs}")

#  The LPJ and MiCASA NPP variables are still in different units (compare the attributes printed above).

In [None]:
# LPJ minus MiCASA NPP. Dividing by 86400 (seconds per day) presumably turns a per-day
# LPJ flux into a per-second rate to match MiCASA.
# NOTE(review): the author was unsure of this conversion — check the `units` attributes
# printed above; if the mass units also differ (e.g. g vs kg), another factor is needed.
seconds_per_day = 86400
result = ds_lpj_align["dnpp"] / seconds_per_day - ds_mi_align["NPP"]
result

In [None]:
# Inspect the time coordinate of the NPP difference.
result.time

In [None]:
# Test 2024: restrict the (still dask-backed) difference to a single year before computing it.
result_sel = result.sel(time="2024")
result_sel

In [None]:
# Evaluate the 2024 difference into memory (runs the dask computation) so it can be plotted.
result_loaded = result_sel.compute()

In [None]:
# Diverging colour normalisation centred on zero, for matplotlib plots of the difference
# (`colors` is imported in the top imports cell).
# NOTE(review): this is not used by the hvplot cell below, which sets `clim` instead, and
# its limits (-5 to 10) are far outside the ±1.5e-7 range used there — TODO confirm the
# intended range before passing it as `norm=`.
divnorm = colors.TwoSlopeNorm(vmin=-5.0, vcenter=0.0, vmax=10)

In [None]:
# Interactive lat/lon map of the daily NPP difference, with a time slider below the plot.
# The colour limits are symmetric about zero so LPJ > MiCASA and LPJ < MiCASA read equally.
# NOTE(review): the author found the groupby widget did not work here and planned
# static matplotlib plots in a separate script instead.
npp_diff_plot_kwargs = {
    "x": "lon",
    "y": "lat",
    "groupby": "time",
    "cmap": "coolwarm",
    "clim": (-1.5e-7, 1.5e-7),
    "height": 400,
    "width": 800,
    "widget_location": "bottom",
    "title": "NPP Difference (LPJ-MiCASA)",
}
result_loaded.hvplot(**npp_diff_plot_kwargs)

## Plotting tests

In [None]:
# NOTE(review): `proj` is not defined anywhere in this notebook; define it
# (e.g. `proj = ccrs.PlateCarree()`) before re-enabling this cell.
# fig, ax = plt.subplots(1, 1, figsize=(12, 8), subplot_kw= {'projection': proj})
# plot = ds_mi_chunk["NPP"].isel(time=0).plot(ax=ax,
#                                       cbar_kwargs=dict(shrink=0.6)
#                                      )
# plt.suptitle("MiCASA test")

In [None]:
# NOTE(review): `proj` is not defined anywhere in this notebook; define it
# (e.g. `proj = ccrs.PlateCarree()`) before re-enabling this cell.
# fig, ax = plt.subplots(1, 1, figsize=(12, 8), subplot_kw= {'projection': proj})
# plot = ds_lpj_combined["dnpp"].isel(time=0).plot(ax=ax, 
#                                                 cbar_kwargs=dict(shrink=0.6))

In [None]:
# NOTE(review): `ds` is not defined in this notebook — presumably `ds_lpj_combined`
# (which has a "dnpp" variable) was meant; confirm before re-enabling.
# ds_sel = ds["dnpp"].isel(time=0)

In [None]:
# NOTE(review): `proj` is not defined anywhere in this notebook, and `ds_sel` comes from
# the disabled cell above; define both before re-enabling this cell.
# fig, ax = plt.subplots(1, 1, figsize=(10, 8), subplot_kw= {'projection': proj})
# ds_sel.plot(ax=ax, transform=ccrs.PlateCarree())

In [None]:
# ds_sel.plot()