In [None]:
# Imports

import os
import sys
import numpy as np
import netCDF4 as nc
from tqdm.auto import tqdm # progress bar
import xarray as xr
import pandas as pd
from eofs.xarray import Eof
from cartopy.crs import PlateCarree
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs

import warnings
warnings.filterwarnings("ignore")

'''
# Import these here just to save time running each separately
from ipynb.fs.full import GRACE_processing
from ipynb.fs.full import ECCO_processing
from ipynb.fs.full import Interpolate_GRACE_to_ECCO
from ipynb.fs.full import ECCO_to_GRACE
from ipynb.fs.full import ECCO_to_GRACE_as_ECCO
from ipynb.fs.full import ECCOv4r5_ctrl
'''

# Imports needed for this code
from ipynb.fs.full import ECCO_time_sets
from ipynb.fs.full import Plots
from ipynb.fs.full import RMS
from ipynb.fs.full import Filter  # For trends and MSC
from ipynb.fs.full import Utils

import xesmf as xe
#sys.path.append('ECCOv4-py/ECCOv4-py')
import ecco_v4_py as ecco
#from config import input_dir, input_dir_ecco_grid

### Read in map data

In [None]:
# Pre-processed inputs produced by earlier notebooks, and where this notebook writes
input_dir = '/glade/u/home/mengnanz/p2702/scripts/saved_data/'
output_dir = './saved_data/'

# ECCO llc90 grid directory (contains ECCO-GRID.nc)
input_dir_ecco_grid = '/glade/u/home/mengnanz/p2375_bp_seasonal_cycle/input_dir/ECCOllc90/r5_nctiles_grid'

# ECCO v4r5 monthly-mean bottom pressure: 1992-2019 record, post-2019 extension, control run
input_dir_ecco_v4r5_to2019 = '/glade/u/home/mengnanz/p2375_bp_seasonal_cycle/input_dir/v4r5_to2019/OBP_mon_mean'
input_dir_ecco_v4r5_ext = '/glade/work/mengnanz/V4r5_ext_2020_2024_Feb/OBP_mon_mean_202001_202402'
input_dir_ecco_ctrl = '/glade/work/mengnanz/V4r5_ctrl/OBP_mon'

# JPL GRACE/GRACE-FO mascon grid; contains 'GRCTellus.JPL.200204_202211.GLO.RL06.1M.MSCNv03CRI.nc'
input_dir_grace = '/glade/work/mengnanz/TELLUS_GRAC-GRFO_MASCON_CRI_GRID_RL06.1_V3_202404/'

In [None]:
# Open the pre-processed ECCO products (all under input_dir), in the original order
aligned_ecco, early_ecco, all_ecco, ecco_ctrl = (
    xr.open_dataset(os.path.join(input_dir, fname))
    for fname in (
        'aligned_ecco.nc',
        'ecco_199201-201912-rm-means.nc',
        'pb_ECCO_all_data.nc',
        'pb_ECCO_ctrl.nc',
    )
)

# ECCO llc90 grid and the early-period weights
ecco_grid = xr.open_dataset(os.path.join(input_dir_ecco_grid, 'ECCO-GRID.nc'))
weight = xr.open_dataset(os.path.join(input_dir, 'ecco_early_weight.nc'))


In [None]:
grace_pb = xr.open_dataset(filename_or_obj='/glade/work/netige/Data/GRACE/pb_graceV2.nc')  # GRACE bottom pressure (V2)

In [None]:
grace_pb_v2_aligned = xr.open_dataset(filename_or_obj='/glade/work/netige/Data/GRACE/pb_graceV2_aligned.nc')  # time axis in decimal years, converted below

In [None]:
# Convert GRACE decimal-year time stamps to month-start datetimes.
#
# The month is still recovered as round(frac * 12 + 1), with NumPy's
# round-half-to-even, as before. It is now cast to int before formatting.
# The earlier float -> str path produced "1.0", so zero-padding never applied
# and the date strings became "2002-1.0-01". The month is also clipped to 1..12,
# so a fraction close to 1.0 cannot yield month 13.
# NOTE(review): if two GRACE epochs fall in the same calendar month they will get
# identical time stamps here; check aligned_grace.time.to_index().is_unique.

time_vals = np.asarray(grace_pb_v2_aligned.time, dtype=float)

base_year = np.floor(time_vals).astype(int)   # integer year part
fractional_part = time_vals - base_year       # fraction of the year elapsed

mon = np.clip(np.rint(fractional_part * 12 + 1).astype(int), 1, 12)

# Build datetimes from numeric components rather than via string concatenation
dates = pd.DatetimeIndex(pd.to_datetime(pd.DataFrame({'year': base_year, 'month': mon, 'day': 1})))

aligned_grace = grace_pb_v2_aligned.assign_coords(time=dates)

In [None]:
# Pre-processed ECCO v4r5 and control bottom pressure, 1992-2019
# (the v4r5 file name suggests means removed; the raw alternative is 'ecco_199201-201912-raw.nc')
ECCO_pb_v4r5_preprocessed, ECCO_pb_ctrl_preprocessed = (
    xr.open_dataset(os.path.join(input_dir, fname))
    for fname in ('ecco_199201-201912-rm-means.nc', 'pb_ECCO_ctrl.nc')
)


In [None]:
# Load the ECCO llc90 grid via ecco_v4_py; its XC, YC and hFacC fields drive the regridding below.
eccor5_grid = ecco.load_ecco_grid_nc(input_dir_ecco_grid, 'ECCO-GRID.nc')

In [None]:
# GRACE put into common grid

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Input dataset and grid
data = aligned_grace.__xarray_dataarray_variable__  # Example data variable
grid = eccor5_grid

# Target resolution for latitude and longitude
new_grid_delta_lat = 0.5  # Latitude resolution (degrees)
new_grid_delta_lon = 0.5  # Longitude resolution (degrees)

# Define global latitude and longitude bounds
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Initialize an empty list to store time steps
global_data_list = []

# Iterate over the time dimension and regrid each time step
for t in range(data.sizes['time']):
    # Select the time slice
    tmp = data.isel(time=t)

    # Mask invalid points using hFacC
    tmp = tmp.where(grid.hFacC.isel(k=0) != 0)

    # Regrid to a latitude-longitude grid
    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp, 
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',  # Use nearest neighbor for simplicity
        fill_value=np.nan
    )

    # Append the regridded data to the list
    global_data_list.append(regridded_data)

# Stack the list into a single 3D array
global_data_array = np.stack(global_data_list, axis=0)

# Define latitude and longitude arrays based on the shape of the global_data_array
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

# Create the xarray.DataArray
regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='pb'
)


import pandas as pd

# Generate the new time coordinate
#new_time = pd.date_range(start="1992-01-01", end="2019-12-31", freq="MS")  # MS = Month Start

# Ensure the length of new_time matches the time dimension of the DataArray
#assert len(new_time) == regridded_data_da.sizes['time'], "Mismatch in time dimension size!"

# Replace the time coordinate
aligned_GRACE_pp_rg = regridded_data_da


In [None]:
aligned_GRACE_pp_rg= aligned_GRACE_pp_rg.sel(time=slice("1992-01-01", "2019-12-01"))


In [None]:
#v4r5 data

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Input dataset and grid
data = ECCO_pb_v4r5_preprocessed.pb  # Example data variable
grid = eccor5_grid

# Target resolution for latitude and longitude
new_grid_delta_lat = 0.5  # Latitude resolution (degrees)
new_grid_delta_lon = 0.5  # Longitude resolution (degrees)

# Define global latitude and longitude bounds
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Initialize an empty list to store time steps
global_data_list = []

# Iterate over the time dimension and regrid each time step
for t in range(data.sizes['time']):
    # Select the time slice
    tmp = data.isel(time=t)

    # Mask invalid points using hFacC
    tmp = tmp.where(grid.hFacC.isel(k=0) != 0)

    # Regrid to a latitude-longitude grid
    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp, 
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',  # Use nearest neighbor for simplicity
        fill_value=np.nan
    )

    # Append the regridded data to the list
    global_data_list.append(regridded_data)

# Stack the list into a single 3D array
global_data_array = np.stack(global_data_list, axis=0)

# Define latitude and longitude arrays based on the shape of the global_data_array
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

# Create the xarray.DataArray
regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='pb'
)


import pandas as pd

# Generate the new time coordinate
new_time = pd.date_range(start="1992-01-01", end="2019-12-31", freq="MS")  # MS = Month Start

# Ensure the length of new_time matches the time dimension of the DataArray
assert len(new_time) == regridded_data_da.sizes['time'], "Mismatch in time dimension size!"

# Replace the time coordinate
ECCO_pb_v4r5_pp_rg = regridded_data_da.assign_coords(time=new_time)


In [None]:
#ctrl data

# Input dataset and grid
data = ECCO_pb_ctrl_preprocessed.pb  # Example data variable
grid = eccor5_grid

# Target resolution for latitude and longitude
new_grid_delta_lat = 0.5  # Latitude resolution (degrees)
new_grid_delta_lon = 0.5  # Longitude resolution (degrees)

# Define global latitude and longitude bounds
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Initialize an empty list to store time steps
global_data_list = []

# Iterate over the time dimension and regrid each time step
for t in range(data.sizes['time']):
    # Select the time slice
    tmp = data.isel(time=t)

    # Mask invalid points using hFacC
    tmp = tmp.where(grid.hFacC.isel(k=0) != 0)

    # Regrid to a latitude-longitude grid
    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp, 
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',  # Use nearest neighbor for simplicity
        fill_value=np.nan
    )

    # Append the regridded data to the list
    global_data_list.append(regridded_data)

# Stack the list into a single 3D array
global_data_array = np.stack(global_data_list, axis=0)

# Define latitude and longitude arrays based on the shape of the global_data_array
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

# Create the xarray.DataArray
regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='pb'
)


import pandas as pd

# Generate the new time coordinate
new_time = pd.date_range(start="1992-01-01", end="2019-12-31", freq="MS")  # MS = Month Start

# Ensure the length of new_time matches the time dimension of the DataArray
assert len(new_time) == regridded_data_da.sizes['time'], "Mismatch in time dimension size!"

# Replace the time coordinate
ECCO_pb_ctrl_pp_rg = regridded_data_da.assign_coords(time=new_time)

In [None]:
grace_pb_error = xr.open_dataset(filename_or_obj='/glade/work/netige/P2702/GRACE_Error_and_Scaled.nc')  # GRACE error and scaled error fields

In [None]:
# Regridding GRACE Error onto a regular 0.5° latitude/longitude grid

import numpy as np
import xarray as xr
import ecco_v4_py as ecco

# Reuse the GRACE error dataset opened in the previous cell from its absolute path.
# Re-opening "GRACE_Error_and_Scaled.nc" relative to the working directory only
# works when the notebook is launched from the folder that holds that file.
ds = grace_pb_error

# Extract lat/lon and fields
lon = ds.lon
lat = ds.lat
error = ds.error
error_scaled = ds.error_scaled

# Shift longitudes from 0–360 to –180–180 if needed
lon_corrected = lon.where(lon <= 180, lon - 360)

# Apply finite mask (shared by both fields)
valid_mask = np.isfinite(error) & np.isfinite(lat) & np.isfinite(lon_corrected)
error_masked = error.where(valid_mask)
error_scaled_masked = error_scaled.where(valid_mask)

# Define regridding target grid (0.5° resolution)
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5

# Regrid both fields with identical settings; only the resampled data (last output) is kept
regridded_error, regridded_error_scaled = (
    ecco.resample_to_latlon(
        lon_corrected, lat, field,
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan
    )[-1]
    for field in (error_masked, error_scaled_masked)
)

# Bin-centre coordinates of the uniform target grid (e.g. -89.75 ... 89.75).
# The previous version labelled each cell with its southern/western edge
# (-90 ... 89.5), i.e. half a cell off.
n_lat, n_lon = regridded_error.shape
lat_1d = new_grid_min_lat + new_grid_delta_lat * (np.arange(n_lat) + 0.5)
lon_1d = new_grid_min_lon + new_grid_delta_lon * (np.arange(n_lon) + 0.5)

# Create regridded DataArrays
error_da = xr.DataArray(
    data=regridded_error,
    dims=["latitude", "longitude"],
    coords={"latitude": lat_1d, "longitude": lon_1d},
    name="error"
)

error_scaled_da = xr.DataArray(
    data=regridded_error_scaled,
    dims=["latitude", "longitude"],
    coords={"latitude": lat_1d, "longitude": lon_1d},
    name="error_scaled"
)

# Combine into a dataset (optionally: grace_error_rg.to_netcdf("GRACE_Error_Regridded.nc"))
grace_error_rg = xr.Dataset({
    "error": error_da,
    "error_scaled": error_scaled_da
})


In [None]:
# Restrict both ECCO runs to the span of the GRACE record ...
grace_time = aligned_GRACE_pp_rg.time
grace_window = slice(grace_time.min(), grace_time.max())

ECCO_pb_v4r5_common = ECCO_pb_v4r5_pp_rg.sel(time=grace_window)
ECCO_pb_ctrl_common = ECCO_pb_ctrl_pp_rg.sel(time=grace_window)

# ... then keep only the months present in both v4r5 and GRACE (GRACE has gaps)
common_time = np.intersect1d(ECCO_pb_v4r5_common.time.values,
                             grace_time.values)

ECCO_pb_v4r5_common, ECCO_pb_ctrl_common, GRACE_common = (
    product.sel(time=common_time)
    for product in (ECCO_pb_v4r5_common, ECCO_pb_ctrl_common, aligned_GRACE_pp_rg)
)

In [None]:
# Temporal standard deviation over the shared months, loaded into memory for plotting
grace_std, ecco_v4r5_std, ecco_ctrl_std = (
    product.std(dim='time', skipna=True).load()
    for product in (GRACE_common, ECCO_pb_v4r5_common, ECCO_pb_ctrl_common)
)
grace_error = grace_error_rg.error.load()

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap, BoundaryNorm

# 2x2 maps: GRACE error and the temporal std. dev. of GRACE, ECCO control and
# ECCO v4r5 over their common months, on a shared discrete 0-5 scale.

# Number of discrete colour bins
n_bins = 10

# Sample evenly from the "jet" colormap. plt.get_cmap(name, lut) replaces
# plt.cm.get_cmap, which Matplotlib deprecated in 3.7 and removed in 3.9.
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

# Create discrete colormap
cmap_disc = ListedColormap(jet_colors)

# Data and titles, in row-major panel order
datasets = [
    (grace_error, "GRACE Error"),
    (grace_std, "GRACE Std. Dev."),
    (ecco_ctrl_std, "ECCO Control Std. Dev."),
    (ecco_v4r5_std, "ECCO v4r5 Std. Dev.")
]
labels = ['(a)', '(b)', '(c)', '(d)']
vmins = [0, 0, 0, 0]
vmaxs = [5, 5, 5, 5]

# Create 2x2 panel
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6),
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)},
                         dpi=150)

# Flatten axes for easy iteration
axes = axes.flatten()

for i, (ax, (field, title), vmin, vmax) in enumerate(zip(axes, datasets, vmins, vmaxs)):
    pc = ax.pcolormesh(field.longitude, field.latitude, field,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc, vmin=vmin, vmax=vmax)
    ax.set_global()
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='#d9d9d9'))
    ax.set_title(title, fontsize=10, fontweight='bold')

    # Longitude labels on the bottom row only, latitude labels on the left column only.
    # bottom_labels / left_labels replace the xlabels_bottom / ylabels_left spellings
    # that cartopy deprecated in 0.18 and later removed.
    gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
    gl.right_labels = False
    gl.top_labels = False
    gl.bottom_labels = i // 2 == 1
    gl.left_labels = i % 2 == 0
    gl.xlabel_style = {'fontsize': 8}
    gl.ylabel_style = {'fontsize': 8}

    ax.text(-0.01, 1.08, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Shared colorbar (all panels use the same cmap and limits)
cbar_ax = fig.add_axes([0.25, 0.05, 0.5, 0.02])  # [left, bottom, width, height]
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal')
cbar.set_label("Ocean Bottom Pressure (cm)", fontsize=12, fontweight="bold")

plt.tight_layout(rect=[0, 0.1, 1, 1])  # Leave space for colorbar
plt.savefig("GRACE_ECCO_Std_Error_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# ECCO v4r5 temporal std. dev. minus the control run's, over the GRACE-overlap months
eco_diff = ecco_v4r5_std - ecco_ctrl_std


In [None]:
# Display the std. dev. difference array
eco_diff

In [None]:
eco_diff.plot.pcolormesh(robust=True)  # quick-look map; robust=True uses percentile-based colour limits

In [None]:
# Map of ECCO v4r5 std. dev. minus control std. dev. on a discrete diverging scale.

# Diverging colours sampled from "coolwarm". The BoundaryNorm spreads the n_bins
# colours across the 0.5 cm bins between -3 and 3.
n_bins = 30
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
coolwarm = plt.get_cmap("coolwarm", n_bins)
cmap_disc = ListedColormap([coolwarm(i) for i in range(n_bins)])
boundaries = np.arange(-3, 3 + 0.5, 0.5)
norm_disc = BoundaryNorm(boundaries, ncolors=n_bins)

data = eco_diff

fig = plt.figure(figsize=(11, 6), dpi=300)
ax = plt.subplot(1, 1, 1, projection=ccrs.PlateCarree())   # <- GeoAxes

pc = ax.pcolormesh(data.longitude,
                   data.latitude,
                   data,
                   transform=ccrs.PlateCarree(),
                   cmap=cmap_disc,
                   norm=norm_disc)
ax.set_global()

# Gridlines with longitude labels at the bottom and latitude labels on the left.
# Labels need draw_labels=True; bottom_labels / left_labels are the current cartopy
# names (xlabels_bottom / ylabels_left were deprecated in 0.18 and later removed).
gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
gl.top_labels = False
gl.right_labels = False
gl.bottom_labels = True
gl.left_labels = True

# Add land feature
ax.add_feature(cfeature.NaturalEarthFeature(
    'physical', 'land', '110m',
    linewidth=0.5, edgecolor='black', facecolor='darkgray'
))

ax.set_title("ECCO v4r5 std. minus ctrl std", fontsize=12)

# Colorbar with one tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("(cm)")

plt.savefig("ECCO_v4r5_std_ctrl_std_dif.png", dpi=500, bbox_inches='tight')

plt.show()

In [None]:
cor = xr.corr(da_a=ECCO_pb_v4r5_common, da_b=ECCO_pb_ctrl_common, dim="time").compute()  # pointwise v4r5-vs-ctrl correlation over time

In [None]:
# Map of the pointwise time correlation between ECCO v4r5 and the ECCO control run
# over the GRACE-overlap months.

# Diverging colours sampled from "coolwarm" and spread over 0.1-wide bins in [-1, 1].
n_bins = 30
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
coolwarm = plt.get_cmap("coolwarm", n_bins)
cmap_disc = ListedColormap([coolwarm(i) for i in range(n_bins)])
# linspace gives exactly 21 edges; arange with a float step can gain or lose an endpoint
boundaries = np.linspace(-1, 1, 21)
norm_disc = BoundaryNorm(boundaries, ncolors=n_bins)

data = cor

fig = plt.figure(figsize=(11, 6), dpi=300)
ax = plt.subplot(1, 1, 1, projection=ccrs.PlateCarree())   # <- GeoAxes

pc = ax.pcolormesh(data.longitude,
                   data.latitude,
                   data,
                   transform=ccrs.PlateCarree(),
                   cmap=cmap_disc,
                   norm=norm_disc)
ax.set_global()

# Gridlines with longitude labels at the bottom and latitude labels on the left
gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
gl.top_labels = False
gl.right_labels = False
gl.bottom_labels = True
gl.left_labels = True

# Add land feature
ax.add_feature(cfeature.NaturalEarthFeature(
    'physical', 'land', '110m',
    linewidth=0.5, edgecolor='black', facecolor='darkgray'
))

# Title and units describe a correlation (the copied std.-difference text and "(cm)" were wrong)
ax.set_title("Correlation of ECCO v4r5 and ECCO ctrl (GRACE period)", fontsize=12)

# Colorbar with one tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("Correlation coefficient")

plt.show()

In [None]:
import numpy as np

# Months shared by the regridded GRACE record and the ECCO v4r5 run
grace_times, ecco_times = aligned_GRACE_pp_rg.time.values, ECCO_pb_v4r5_pp_rg.time.values
common_times = np.intersect1d(grace_times, ecco_times)

# Subset all three products to those months
grace_common, ecco_v4r5_common, ecco_ctrl_common = (
    product.sel(time=common_times)
    for product in (aligned_GRACE_pp_rg, ECCO_pb_v4r5_pp_rg, ECCO_pb_ctrl_pp_rg)
)

# Pairwise residual time series
grace_min_v4r5 = grace_common - ecco_v4r5_common
grace_min_ctrl = grace_common - ecco_ctrl_common
ctrl_min_v4r5 = ecco_ctrl_common - ecco_v4r5_common



In [None]:
# Temporal std. dev. of each residual series
grace_min_v4r5_std, grace_min_ctrl_std, ctrl_min_v4r5_std = (
    residual.std(dim="time") for residual in (grace_min_v4r5, grace_min_ctrl, ctrl_min_v4r5)
)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# 2x2 maps of residual std. devs. (a-c) and their difference (d), each panel with its own colorbar.

# Difference dataset
grace_dif = grace_min_v4r5_std - grace_min_ctrl_std

# Dataset and metadata
data_list = [grace_min_v4r5_std, grace_min_ctrl_std, ctrl_min_v4r5_std, grace_dif]
title_list = ["Std.dev.(GRACE - v4r5)",
              "Std.dev.(GRACE - ctrl)",
              "Std.dev.(ctrl - v4r5)",
              "Std. Dev. (GRACE-v4r5) minus Std. Dev. (GRACE-ctrl)"]
labels = ['(a)', '(b)', '(c)', '(d)']

# Panels (a)-(c): 10 discrete "jet" colours over 0-5 in 0.5-wide bins.
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
n_bins = 10
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]
cmap_disc = ListedColormap(jet_colors)
boundaries = np.arange(0, 5.0 + 0.5, 0.5)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))

# Panel (d): discrete diverging colormap, 8 colours spanning -0.7 to 0.7
div_colors = [
    "#2166ac",
    "#4393c3",
    "#92c5de",
    "#d1e5f0",
    "#fddbc7",
    "#f4a582",
    "#d6604d",
    "#b2182b"
]
cmap_div_disc = ListedColormap(div_colors)
boundaries_diff = np.linspace(-0.7, 0.7, len(div_colors) + 1)
norm_diff = BoundaryNorm(boundaries_diff, ncolors=len(div_colors))

# Create subplots
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 4.5), dpi=300,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)},
                         gridspec_kw={'hspace': 0.01})  # Reduce from default ~0.2

axes = axes.flatten()

for i, (ax, field, title) in enumerate(zip(axes, data_list, title_list)):
    is_diff_panel = i >= 3
    cmap, norm = (cmap_div_disc, norm_diff) if is_diff_panel else (cmap_disc, norm_disc)
    pc = ax.pcolormesh(field.longitude, field.latitude, field,
                       transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

    ax.set_global()

    # Longitude labels on the bottom row only, latitude labels on the left column only.
    # Labels need draw_labels=True; bottom_labels / left_labels replace the removed
    # xlabels_bottom / ylabels_left cartopy attributes.
    gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i // 2 == 1
    gl.left_labels = i % 2 == 0
    gl.right_labels = False
    gl.top_labels = False
    gl.xlabel_style = {'fontsize': 7}
    gl.ylabel_style = {'fontsize': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'))

    ax.set_title(title, fontsize=8, fontweight='bold')
    ax.text(-0.09, 1.01, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # Vertical colorbar to the right of each panel, ticked at the bin edges
    cbar = fig.colorbar(pc, ax=ax, orientation='vertical', pad=0.05, fraction=0.05, shrink=0.7)
    if is_diff_panel:
        cbar.set_ticks(boundaries_diff)
        cbar.ax.set_yticklabels([f'{b:.3f}' for b in boundaries_diff], fontsize=7)
        # Zero-line contour separating where v4r5 or ctrl fits GRACE better
        ax.contour(field.longitude, field.latitude, field,
                   levels=[0], colors='black', linewidths=0.5,
                   transform=ccrs.PlateCarree())
    else:
        cbar.set_ticks(boundaries)
        cbar.ax.set_yticklabels([f'{b:.1f}' for b in boundaries], fontsize=7)


plt.tight_layout()
plt.savefig("OBP_Dif_Std_2x2_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Define a function to compute trends manually over latitude and longitude indices
def compute_trends_manually(data_array, lat_chunk=30):
    """Ordinary-least-squares linear trend of every grid point's time series.

    Equivalent to calling ``scipy.stats.linregress`` on each ``(lat, lon)`` series
    after dropping NaNs. The slopes are computed with vectorised masked sums,
    one block of latitude rows at a time, instead of one Python-level regression
    per grid point.

    Parameters
    ----------
    data_array : xarray.DataArray
        Values ordered as ``(time, latitude, longitude)``. Only
        ``coords["time"].values`` (datetime64), ``sizes["latitude"]``,
        ``sizes["longitude"]`` and ``values`` are used.
    lat_chunk : int, optional
        Number of latitude rows processed per vectorised step. This bounds the
        size of temporary arrays and does not affect the result.

    Returns
    -------
    numpy.ndarray
        ``(latitude, longitude)`` array of slopes in data units per day.
        Points with fewer than two non-NaN samples are NaN.

    Raises
    ------
    ValueError
        If ``values`` is not shaped ``(time, latitude, longitude)``.
    """
    time = data_array.coords["time"].values
    time_numeric = (time - time[0]) / np.timedelta64(1, 'D')  # days since the first sample

    lat_size, lon_size = data_array.sizes["latitude"], data_array.sizes["longitude"]
    raw = data_array.values
    expected_shape = (time_numeric.size, lat_size, lon_size)
    if raw.shape != expected_shape:
        raise ValueError(
            f"expected values shaped (time, latitude, longitude) = {expected_shape}, got {raw.shape}"
        )

    trend_array = np.full((lat_size, lon_size), np.nan)
    step = max(int(lat_chunk), 1)
    for start in range(0, lat_size, step):
        stop = min(start + step, lat_size)
        block = np.asarray(raw[:, start:stop], dtype=float)
        trend_array[start:stop] = _masked_ols_slope(time_numeric, block)

    return trend_array


def _masked_ols_slope(x, y):
    """OLS slope of ``y`` (time, ...) against ``x`` (time,) along axis 0, ignoring NaNs in ``y``."""
    valid = ~np.isnan(y)
    n = valid.sum(axis=0)
    x_full = np.broadcast_to(x.reshape((-1,) + (1,) * (y.ndim - 1)), y.shape)
    # Fully masked or single-sample points divide by zero; they are replaced by NaN below.
    with np.errstate(invalid="ignore", divide="ignore"):
        x_mean = np.where(valid, x_full, 0.0).sum(axis=0) / n
        y_mean = np.where(valid, y, 0.0).sum(axis=0) / n
        dx = np.where(valid, x_full - x_mean, 0.0)
        dy = np.where(valid, y - y_mean, 0.0)
        slope = (dx * dy).sum(axis=0) / (dx * dx).sum(axis=0)
    return np.where(n > 1, slope, np.nan)  # need at least two points for a regression

In [None]:
%%time
# Compute trends

grace_common_trend     = compute_trends_manually(grace_common)
ecco_v4r5_common_trend = compute_trends_manually(ecco_v4r5_common)
ecco_ctrl_common_trend = compute_trends_manually(ecco_ctrl_common)

grace_min_v4r5_trend = compute_trends_manually(grace_min_v4r5)
grace_min_ctrl_trend = compute_trends_manually(grace_min_ctrl)
ctrl_min_v4r5_trend = compute_trends_manually(ctrl_min_v4r5)



In [None]:
# Convert slopes from per-day to per-year (365-day year)
(grace_common_trend_pyr,
 ecco_v4r5_common_trend_pyr,
 ecco_ctrl_common_trend_pyr,
 grace_min_v4r5_trend_pyr,
 grace_min_ctrl_trend_pyr,
 ctrl_min_v4r5_trend_pyr) = (
    per_day * 365
    for per_day in (grace_common_trend, ecco_v4r5_common_trend, ecco_ctrl_common_trend,
                    grace_min_v4r5_trend, grace_min_ctrl_trend, ctrl_min_v4r5_trend)
)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt

# 2x3 maps of per-year trends over the GRACE-overlap period: products (top row) and residuals (bottom row).

# Define the datasets and titles
data_list = [grace_common_trend_pyr,
             ecco_v4r5_common_trend_pyr,
             ecco_ctrl_common_trend_pyr,
             grace_min_v4r5_trend_pyr,
             grace_min_ctrl_trend_pyr,
             ctrl_min_v4r5_trend_pyr]

title_list = ["GRACE Trends",
              "ECCO v4r5 Trends",
              "ECCO ctrl Trends",
              "GRACE - v4r5 Trends",
              "GRACE - ctrl Trends",
              "ctrl - v4r5 Trends"]

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 4.5), dpi=500, subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, field, title) in enumerate(zip(axes.flat, data_list, title_list)):
    # Trend arrays are plain (lat, lon) numpy arrays; take coordinates from the common grid
    pc = ax.pcolormesh(grace_common.longitude,
                       grace_common.latitude,
                       field,
                       transform=ccrs.PlateCarree(),
                       cmap='coolwarm',
                       vmin=-0.3,
                       vmax=0.3)
    ax.set_global()

    # Longitude labels on the bottom row (panels d-f), latitude labels on the left column (a, d).
    # Labels need draw_labels=True; bottom_labels / left_labels replace the removed
    # xlabels_bottom / ylabels_left attributes. The very thin linewidth keeps lines faint.
    gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = i >= 3
    gl.left_labels = i % 3 == 0

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='darkgray'))

    ax.set_title(title, fontsize=12)

    # Add (a)-(f) label to top-left
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Add colorbar below the bottom figure (all panels share cmap and limits)
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal')
cbar.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("OBP_Trends_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


### Compute ECCO trends for the Full ECCO Period

In [None]:
# Control minus v4r5 over the full ECCO record
diff = ECCO_pb_ctrl_pp_rg - ECCO_pb_v4r5_pp_rg

# Per-day trends for v4r5, ctrl and their difference ...
v4r5_full_trend, ctrl_full_trend, diff_full_trend = (
    compute_trends_manually(series)
    for series in (ECCO_pb_v4r5_pp_rg, ECCO_pb_ctrl_pp_rg, diff)
)

# ... scaled to per-year (365-day year)
v4r5_full_trend_pyr, ctrl_full_trend_pyr, diff_full_trend_pyr = (
    per_day * 365 for per_day in (v4r5_full_trend, ctrl_full_trend, diff_full_trend)
)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt

# 1x3 maps of per-year trends over the full ECCO record.

# Define the datasets and titles
data_list = [v4r5_full_trend_pyr,
             ctrl_full_trend_pyr,
             diff_full_trend_pyr]

title_list = ["ECCO v4r5 Trends (Full Period)",
              "ECCO ctrl Trends (Full Period)",
              "ctrl - v4r5 Trends (Full Period)"]

labels = ['(a)', '(b)', '(c)']

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 2.5), dpi=500, subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, field, title) in enumerate(zip(axes.flat, data_list, title_list)):
    # Trend arrays are plain (lat, lon) numpy arrays on the shared regridded grid
    pc = ax.pcolormesh(grace_common.longitude,
                       grace_common.latitude,
                       field,
                       transform=ccrs.PlateCarree(),
                       cmap='coolwarm',
                       vmin=-0.3,
                       vmax=0.3)
    ax.set_global()

    # Single row: longitude labels on every panel, latitude labels on the first only.
    # Labels need draw_labels=True; bottom_labels / left_labels replace the removed
    # xlabels_bottom / ylabels_left attributes.
    gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = True
    gl.left_labels = i == 0

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='darkgray'))

    ax.set_title(title, fontsize=12)

    # Add (a)-(c) label to top-left
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Add colorbar below the panels (all panels share cmap and limits)
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.04])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal')
cbar.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("ECCO_Full_OBP_Trends.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Trend difference figure: full-period trend minus GRACE-period (common) trend,
# for ECCO v4r5, ECCO ctrl, and ctrl - v4r5.

v4r5_dif = v4r5_full_trend_pyr - ecco_v4r5_common_trend_pyr
ctrl_dif = ctrl_full_trend_pyr - ecco_ctrl_common_trend_pyr
# diff_full_trend_pyr is the full-period "ctrl - v4r5" trend, so its GRACE-period
# counterpart is ctrl - v4r5 over the common period. (Previously this subtracted
# grace_min_v4r5_trend_pyr, i.e. a GRACE - v4r5 trend, which does not match the
# "ctrl - v4r5" panel title.)
diff_dif = diff_full_trend_pyr - (ecco_ctrl_common_trend_pyr - ecco_v4r5_common_trend_pyr)

data_list = [v4r5_dif,
             ctrl_dif,
             diff_dif]

title_list = ["ECCO v4r5 Trends (Full Period - GRACE Period)", 
              "ECCO ctrl Trends (Full Period- GRACE Period)",
             "ctrl - v4r5 Trends (Full Period- GRACE Period)"]

labels = ['(a)', '(b)', '(c)']

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 2.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    # Continuous coolwarm scale; every panel shares the same +/-0.3 limits
    pc = ax.pcolormesh(grace_common.longitude,
                       grace_common.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap='coolwarm',
                       vmin=-0.3,
                       vmax=0.3)
    ax.set_global()

    # Gridlines (linewidth 0.01 keeps them almost invisible). Uses cartopy's
    # current Gridliner names; xlabels_bottom/ylabels_left are deprecated aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = True      # single row of panels: all show longitude labels
    gl.left_labels = i == 0      # latitude labels only on the leftmost panel

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='darkgray'))

    # Smaller font: these titles are long
    ax.set_title(title, fontsize=7)

    # Panel label (a)-(c) at the top-left corner
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Shared horizontal colorbar below the panels (identical limits on all panels)
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.04])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal')
cbar.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("ECCO_Full_min_GRACE_OBP_Trends.png", dpi=500, bbox_inches='tight')
plt.show()

### Compare GRACE Raw and GRACE Corrected Trends

In [None]:
grace_pb_cor_aligned = xr.open_dataset('/glade/work/netige/Data/GRACE/pb_graceV2_GRD_corrected_aligned.nc')

# --- Convert GRACE decimal-year times to month-start datetimes ---
# NOTE(review): assumes `time` stores decimal years (e.g. 2002.29) -- confirm.
time_vals = np.asarray(grace_pb_cor_aligned.time, dtype=float)

base_year = np.floor(time_vals).astype(int)   # integer year part
fractional_part = time_vals - base_year       # fraction of the year elapsed

# Nearest month number (1-based; np.rint rounds half-to-even like .round(0)).
# Cast to int so the month is formatted "04" rather than the float text "4.0",
# which produced non-ISO strings such as "2002-4.0-01".
mon = np.rint(fractional_part * 12 + 1).astype(int)

# A fraction close to 1 rounds to month 13: roll it into January of the next year.
rollover = mon > 12
base_year = base_year + rollover
mon = np.where(rollover, 1, mon)

date_str = [f"{year:04d}-{month:02d}-01" for year, month in zip(base_year, mon)]
dates = pd.to_datetime(date_str, format="%Y-%m-%d")

grace_pb_cor_aligned = grace_pb_cor_aligned.assign_coords(time=dates)


# --- Put the corrected GRACE field onto the common lat/lon grid ---

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Input data variable and ECCO grid (eccor5_grid is defined elsewhere in the
# notebook; it must provide XC, YC and hFacC)
data = grace_pb_cor_aligned.pb_corrected
grid = eccor5_grid

# Target resolution for latitude and longitude (degrees)
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5

# Global latitude and longitude bounds
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Regrid one time step at a time and collect the 2-D results
global_data_list = []
for t in range(data.sizes['time']):
    tmp = data.isel(time=t)

    # Mask dry points using the surface-level hFacC
    tmp = tmp.where(grid.hFacC.isel(k=0) != 0)

    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp,
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan
    )
    global_data_list.append(regridded_data)

# Stack into a (time, lat, lon) array
global_data_array = np.stack(global_data_list, axis=0)

# NOTE(review): linspace from -90..90 / -180..180 spans the grid *edges*
# inclusively, so these are not the cell centres that resample_to_latlon
# computes. Presumably this matches how grace_common was built elsewhere -- verify
# before changing, because `dif` subtracts the two and xarray aligns on
# coordinate values.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])

regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='pb'
)


import pandas as pd

aligned_GRACE_cor_pp_rg = regridded_data_da

# Restrict to the common GRACE/ECCO period (common_time defined elsewhere)
GRACE_cor_common = aligned_GRACE_cor_pp_rg.sel(time=common_time)

# Trend scaled by 365 -- presumably per-day to per-year; confirm against
# compute_trends_manually
grace_cor_common_trend = compute_trends_manually(GRACE_cor_common)
grace_cor_common_trend_pyr = grace_cor_common_trend * 365

In [None]:
# Raw-minus-corrected GRACE trend map (cm/yr)
dif = (
    grace_common_trend_pyr
    - grace_cor_common_trend_pyr
)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Discrete colour scale: 0.05 cm/yr bins across [-0.3, 0.3] (12 bins, 13 edges).
# linspace avoids the endpoint fragility of arange with a float step.
boundaries = np.linspace(-0.3, 0.3, 13)
n_bins = len(boundaries) - 1   # one colour per bin (previously 25 colours for 12 bins)

# Sample n_bins evenly spaced colours from "coolwarm". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("coolwarm", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


# Define the datasets and titles
data_list = [grace_common_trend_pyr,
             grace_cor_common_trend_pyr]

title_list = ["GRACE Raw Trends",
             "GRACE Corrected Trends"]

labels = ['(a)', '(b)']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 3.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(grace_common.longitude,
                       grace_common.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)
    ax.set_global()

    # Gridlines (linewidth 0.01 keeps them almost invisible). Uses cartopy's
    # current Gridliner names; xlabels_bottom/ylabels_left are deprecated aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = True      # single row of panels: all show longitude labels
    gl.left_labels = i == 0      # latitude labels only on the leftmost panel

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='darkgray'))

    ax.set_title(title, fontsize=12)

    # Panel label (a)-(b) at the top-left corner
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Shared discrete colorbar below the panels, one tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal',
                   ticks=boundaries)
cbar.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
#plt.savefig("OBP_Trends_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Dimensions of the raw-minus-corrected trend difference
np.shape(dif)

In [None]:
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# --- discrete setup for panel (c): 10 bins across [-0.3, 0.3] ---
n_bins = 10
boundaries = np.linspace(-0.3, 0.3, n_bins + 1)          # 11 edges for 10 bins
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
cmap_disc = plt.get_cmap("coolwarm", n_bins)             # 10 discrete colors
norm_disc = BoundaryNorm(boundaries, ncolors=n_bins)

# --- inputs ---
data_list  = [grace_common_trend_pyr,      # (a)
              grace_cor_common_trend_pyr,  # (b)
              dif]                         # (c)
title_list = ["GRACE Raw Trends", "GRACE Corrected Trends", "Raw - Corrected"]
labels     = ['(a)', '(b)', '(c)']

# shared continuous limits for (a) & (b)
vmin_ab, vmax_ab = -0.3, 0.3

fig, axes = plt.subplots(
    1, 3, figsize=(12, 3), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

pcs = []
for i, (ax, data, title) in enumerate(zip(axes, data_list, title_list)):
    if i < 2:
        # (a), (b): continuous scale
        pc = ax.pcolormesh(
            grace_common.longitude, grace_common.latitude, data,
            transform=ccrs.PlateCarree(), cmap="coolwarm",
            vmin=vmin_ab, vmax=vmax_ab
        )
    else:
        # (c): discrete bins
        pc = ax.pcolormesh(
            grace_common.longitude, grace_common.latitude, data,
            transform=ccrs.PlateCarree(), cmap=cmap_disc, norm=norm_disc
        )
    pcs.append(pc)

    ax.set_global()
    ax.add_feature(
        cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                     linewidth=0.5, edgecolor='black', facecolor='darkgray')
    )
    ax.gridlines(color='gray', linestyle='dashed', linewidth=0.2, draw_labels=False)
    ax.set_title(title, fontsize=11)
    ax.text(-0.01, 1.07, labels[i], transform=ax.transAxes,
            fontsize=9, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.0, alpha=0.6))

# Leave space at bottom for two horizontal colorbars; positions are read after
# tight_layout so the colorbars line up with the final axes
plt.tight_layout(rect=[0, 0.18, 1, 1])

# --- Horizontal colorbar for (a)+(b) sharing continuous limits ---
pos_a = axes[0].get_position()
pos_b = axes[1].get_position()
left  = min(pos_a.x0, pos_b.x0)
right = max(pos_a.x1, pos_b.x1)
bottom = min(pos_a.y0, pos_b.y0)
cax_ab = fig.add_axes([left, bottom - 0.08, right - left, 0.025])
cb_ab = fig.colorbar(pcs[0], cax=cax_ab, orientation='horizontal')
cb_ab.set_label("Ocean Bottom Pressure Trend (cm/yr)")

# --- Horizontal colorbar for (c) using discrete bins ---
pos_c = axes[2].get_position()
cax_c = fig.add_axes([pos_c.x0, pos_c.y0 - 0.08, pos_c.width, 0.025])
cb_c = fig.colorbar(pcs[2], cax=cax_c, orientation='horizontal',
                    ticks=boundaries[::2])  # every other edge to avoid clutter
cb_c.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.show()


In [None]:
def mean_seasonal_cycle(ds, time_dim="time", groupby="time.month"):
    """
    Compute the mean seasonal cycle (climatology) of an xarray object.

    The data are averaged over ``time_dim`` within each group defined by
    ``groupby``. The cycle itself is returned; nothing is subtracted from ``ds``.

    Parameters:
    ds (xr.Dataset or xr.DataArray): Input dataset or data array.
    time_dim (str): Name of the time dimension that is averaged over.
    groupby (str): Grouping key. The default, 'time.month', gives a monthly
        climatology.

    Returns:
    xr.Dataset or xr.DataArray: Group means. For the default ``groupby``, the
    time dimension is replaced by a ``month`` dimension.
    """
    climatology = ds.groupby(groupby).mean(dim=time_dim)
    return climatology

In [None]:
%%time
# Compute MSC and their std.

grace_common_msc    = mean_seasonal_cycle(grace_common)
ecco_v4r5_common_msc = mean_seasonal_cycle(ecco_v4r5_common)
ecco_ctrl_common_msc = mean_seasonal_cycle(ecco_ctrl_common)

#grace_min_v4r5_msc = mean_seasonal_cycle(grace_min_v4r5)
#grace_min_ctrl_msc = mean_seasonal_cycle(grace_min_ctrl)
#ctrl_min_v4r5_msc = mean_seasonal_cycle(ctrl_min_v4r5)


grace_min_v4r5_msc = grace_common_msc - ecco_v4r5_common_msc
grace_min_ctrl_msc = grace_common_msc - ecco_ctrl_common_msc
ctrl_min_v4r5_msc =  ecco_ctrl_common_msc - ecco_v4r5_common_msc


# Compute residual std.
grace_common_msc_std = grace_common_msc.std(dim = "month") 
ecco_v4r5_common_msc_std = ecco_v4r5_common_msc.std(dim = "month") 
ecco_ctrl_common_msc_std = ecco_ctrl_common_msc.std(dim = "month") 
grace_min_v4r5_msc_std = grace_min_v4r5_msc.std(dim = "month") 
grace_min_ctrl_msc_std = grace_min_ctrl_msc.std(dim = "month") 
ctrl_min_v4r5_msc_std = ctrl_min_v4r5_msc.std(dim = "month") 

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# Panels: MSC std of each dataset (top row) and of each MSC difference (bottom row)
data_list = [
    grace_common_msc_std, 
    ecco_v4r5_common_msc_std, 
    ecco_ctrl_common_msc_std,
    grace_min_v4r5_msc_std,
    grace_min_ctrl_msc_std,
    ctrl_min_v4r5_msc_std
]

title_list = [
    "GRACE MSC", 
    "ECCO v4r5 MSC", 
    "ECCO ctrl MSC",
    "GRACE - v4r5 MSC",
    "GRACE - ctrl MSC",
    "ctrl - v4r5 MSC"
]

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

# Custom 6-color palette: one colour per 0.5 cm bin
custom_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"]
cmap = ListedColormap(custom_colors)

# Boundaries for 6 intervals: [0.0, 0.5, 1.0, ..., 3.0]
boundaries = np.arange(0, 3.0 + 0.5, 0.5)
norm = BoundaryNorm(boundaries, ncolors=len(custom_colors))

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 4.5), dpi=300, 
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap,
                       norm=norm)

    ax.set_global()

    # Gridlines: longitude labels on the bottom row, latitude labels on the left
    # column. Uses cartopy's current Gridliner names throughout (bottom_labels /
    # left_labels, like right_labels / top_labels); xlabels_bottom and
    # ylabels_left are deprecated aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i >= 3
    gl.left_labels = i % 3 == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=12)

    # Subplot labels (a)-(f)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared discrete colorbar with one labelled tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("Ocean Bottom Pressure Standard Deviation (cm)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("OBP_MSC_std_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')

plt.show()


In [None]:
# Same MSC-std figure as above, but with a discrete "jet" colour scale

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# 0.5 cm bins across [0, 3] (6 bins, 7 edges), one colour per bin
boundaries = np.arange(0, 3.0 + 0.5, 0.5)
n_bins = len(boundaries) - 1

# Sample n_bins evenly spaced colours from "jet". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))

# Panels: MSC std of each dataset (top row) and of each MSC difference (bottom row)
data_list = [
    grace_common_msc_std, 
    ecco_v4r5_common_msc_std, 
    ecco_ctrl_common_msc_std,
    grace_min_v4r5_msc_std,
    grace_min_ctrl_msc_std,
    ctrl_min_v4r5_msc_std
]

title_list = [
    "GRACE MSC", 
    "ECCO v4r5 MSC", 
    "ECCO ctrl MSC",
    "GRACE - v4r5 MSC",
    "GRACE - ctrl MSC",
    "ctrl - v4r5 MSC"
]

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 4.5), dpi=300, 
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridlines: longitude labels on the bottom row, latitude labels on the left
    # column (cartopy's current Gridliner names; xlabels_*/ylabels_* are deprecated)
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i >= 3
    gl.left_labels = i % 3 == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=12)

    # Subplot labels (a)-(f)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared discrete colorbar with one labelled tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("Ocean Bottom Pressure Standard Deviation (cm)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("OBP_MSC_std_Fig_GRACEV2_new_color.png", dpi=500, bbox_inches='tight')

plt.show()


### Exploratory plot: panel (d) minus panel (e) of the MSC standard-deviation figure

In [None]:
# Discrete diverging scale: 0.1 cm bins across [-0.4, 0.4] (8 bins, 9 edges)
boundaries = np.linspace(-0.4, 0.4, 9)
n_bins = len(boundaries) - 1   # one colour per bin (previously 10 colours for 8 bins)

# Sample n_bins evenly spaced colours from "coolwarm". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("coolwarm", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


# Panel (d) minus panel (e) of the MSC-std figure
data = grace_min_v4r5_msc_std - grace_min_ctrl_msc_std

fig = plt.figure(figsize=(11, 6), dpi=300)
ax = plt.subplot(1, 1, 1, projection=ccrs.PlateCarree())   # GeoAxes


pc = ax.pcolormesh(data.longitude,
                   data.latitude,
                   data,
                   transform=ccrs.PlateCarree(),
                   cmap=cmap_disc,
                   norm=norm_disc)
ax.set_global()

# Gridlines with bottom/left labels (cartopy's current Gridliner names;
# xlabels_bottom/ylabels_left are deprecated aliases)
gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
gl.bottom_labels = True
gl.left_labels = True

# Add land feature
ax.add_feature(cfeature.NaturalEarthFeature(
    'physical', 'land', '110m',
    linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

ax.set_title("(GRACE - v4r5) MSC minus (GRACE - ctrl) MSC", fontsize=12)


# Discrete colorbar with one labelled tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("(cm)")

plt.savefig("Fig4d_minus_Fig4e.png", dpi=500, bbox_inches='tight')

plt.show()

### Similar plot to Figure 4 of the ECCO paper, but using residuals (trend and MSC removed)
1. Remove the MSC and trend from each OBP dataset.
2. Compute the GRACE residual std.
3. Compute the ECCO v4r5 and ctrl residual std.
4. Compute the GRACE - v4r5 residual.
5. Compute the GRACE - ctrl residual.
6. Compute the ctrl - v4r5 residual.
7. Compute the std of each of the residuals from steps 4, 5, and 6.
   

In [None]:
# Functions needed
def detrend_dim(da, dim='time'):
    """
    Remove a least-squares linear trend along ``dim`` from every series.

    The regression abscissa is the integer sample index (0, 1, 2, ...), not the
    coordinate values, so the fit implicitly assumes evenly spaced samples.
    Non-finite values are excluded from the fit and stay NaN in the output.
    A series with fewer than two finite values is returned unchanged.

    Parameters
    ----------
    da : xarray object
        Input data containing dimension ``dim``.
    dim : str
        Dimension along which each series is detrended.

    Returns
    -------
    Object of the same kind as ``da`` with the linear trend removed. As with any
    ``xr.apply_ufunc`` core dimension, ``dim`` becomes the last dimension.
    """
    # Local import: linregress is not imported in the notebook's import cell
    from scipy.stats import linregress

    # Sample index used as the regression abscissa
    x = np.arange(da.sizes[dim])

    def detrend_func(y):
        mask = np.isfinite(y)  # fit only the finite samples
        if mask.sum() > 1:
            slope, intercept, _, _, _ = linregress(x[mask], y[mask])
            return y - (slope * x + intercept)
        return y  # too few valid points to fit a line

    return xr.apply_ufunc(
        detrend_func,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        vectorize=True,
        dask="parallelized"
    )

def remove_seasonal_cycle(ds, time_dim="time", groupby="time.month"):
    """
    Subtract the mean seasonal cycle (climatology) from an xarray object.

    Parameters:
    ds (xr.Dataset or xr.DataArray): Input dataset or data array.
    time_dim (str): Dimension averaged over when building the climatology.
    groupby (str): Grouping key; the default 'time.month' gives a monthly
        climatology.

    Returns:
    xr.Dataset or xr.DataArray: Anomalies of ``ds`` relative to the mean of
    its group.
    """
    grouped = ds.groupby(groupby)
    group_means = grouped.mean(dim=time_dim)
    # GroupBy arithmetic subtracts each group's mean from that group's members
    return grouped - group_means


In [None]:
# Linearly detrended ECCO v4r5 and ctrl series over the common period
ecco_v4r5_common_dt, ecco_ctrl_common_dt = (
    detrend_dim(ecco_v4r5_common),
    detrend_dim(ecco_ctrl_common),
)

In [None]:
# Correlation over time between v4r5 and ctrl at each grid point (trend kept)
cor = xr.corr(
    ecco_v4r5_common,
    ecco_ctrl_common,
    dim="time",
).compute()

In [None]:
# Same correlation after removing each series' linear trend
cor_dt = xr.corr(
    ecco_v4r5_common_dt,
    ecco_ctrl_common_dt,
    dim="time",
).compute()

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Discrete diverging scale for correlations: 0.1 bins across [-1, 1]
# (20 bins, 21 edges)
boundaries = np.linspace(-1, 1, 21)
n_bins = len(boundaries) - 1   # one colour per bin (previously 30 colours for 20 bins)

# Sample n_bins evenly spaced colours from "coolwarm". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("coolwarm", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


# Define the datasets and titles
data_list = [cor,
             cor_dt]

title_list = ["With trend",
             "Without trend"]

labels = ['(a)', '(b)']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 3.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(grace_common.longitude,
                       grace_common.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)
    ax.set_global()

    # Gridlines (linewidth 0.01 keeps them almost invisible). Uses cartopy's
    # current Gridliner names; xlabels_bottom/ylabels_left are deprecated aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = True      # single row of panels: all show longitude labels
    gl.left_labels = i == 0      # latitude labels only on the leftmost panel

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                linewidth=0.5,
                                                edgecolor='black',
                                                facecolor='darkgray'))

    ax.set_title(title, fontsize=12)

    # Panel label (a)-(b) at the top-left corner
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))


# Shared discrete colorbar below the panels, one tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal',
                   ticks=boundaries)
cbar.set_label("")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.suptitle("Correlation between ECCO v4r5 and ctrl")
plt.savefig("ECCO_v4r5_ctrl_cor.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
%%time

grace_dt_sr = remove_seasonal_cycle(detrend_dim(grace_common))
v4r5_dt_sr = remove_seasonal_cycle(detrend_dim(ecco_v4r5_common))
ctrl_dt_sr = remove_seasonal_cycle(detrend_dim(ecco_ctrl_common))

grace_min_v4r5_res = grace_dt_sr - v4r5_dt_sr
grace_min_ctrl_res = grace_dt_sr - ctrl_dt_sr
ctrl_min_v4r5_res =  ctrl_dt_sr - v4r5_dt_sr

# Compute Std.
grace_common_res_std = grace_dt_sr.std(dim = "time") 
ecco_v4r5_common_res_std = v4r5_dt_sr.std(dim = "time") 
ecco_ctrl_common_res_std = ctrl_dt_sr.std(dim = "time") 
grace_min_v4r5_res_std = grace_min_v4r5_res.std(dim = "time") 
grace_min_ctrl_res_std = grace_min_ctrl_res.std(dim = "time") 
ctrl_min_v4r5_res_std = ctrl_min_v4r5_res.std(dim = "time") 

### Sanity check: compare the raw, MSC, and residual standard deviations to confirm that the detrending and deseasonalizing functions behave as expected.

In [None]:
# Raw (trend and seasonal cycle included) temporal std of each series
grace_common_std, ecco_v4r5_common_std, ecco_ctrl_common_std = (
    series.std(dim = 'time')
    for series in (grace_common, ecco_v4r5_common, ecco_ctrl_common)
)

# Nine panels: raw std (row 1), MSC std (row 2), residual std (row 3)
data_list = [
    grace_common_std, ecco_v4r5_common_std, ecco_ctrl_common_std,
    grace_common_msc_std, ecco_v4r5_common_msc_std, ecco_ctrl_common_msc_std,
    grace_common_res_std, ecco_v4r5_common_res_std, ecco_ctrl_common_res_std,
]

In [None]:
# Sanity-check plot: raw std (row 1), MSC std (row 2) and residual std (row 3)
# for GRACE, v4r5 and ctrl

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# Discrete "jet" scale: 0.25 cm bins across [0, 5] (20 bins, 21 edges)
boundaries = np.linspace(0, 5.0, 21)
n_bins = len(boundaries) - 1   # one colour per bin

# Sample n_bins evenly spaced colours from "jet". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


# Panel (i) shows ecco_ctrl_common_res_std, so its title is "ctrl res std"
# (it was mislabelled "ctrl MSC std")
title_list = ['GRACE std', 'v4r5_std', 'ctrl_std',
              'GRACE MSC std', 'v4r5 MSC std', 'ctrl MSC std',
              'GRACE res std', 'v4r5 res std', 'ctrl res std']

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)']

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 8), dpi=300, 
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridlines: longitude labels only on the bottom row of the 3x3 grid
    # (indices 6-8; `i >= 3` also labelled the middle row) and latitude labels on
    # the left column. Uses cartopy's current Gridliner names.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i >= 6
    gl.left_labels = i % 3 == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=12)

    # Subplot labels (a)-(i)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared discrete colorbar with one labelled tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.2f}' for b in boundaries])
cbar.set_label("Ocean Bottom Pressure Standard Deviation (cm)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
#plt.savefig("OBP_MSC_std_Fig_GRACEV2_new_color.png", dpi=500, bbox_inches='tight')

plt.show()


In [None]:
# Raw std minus residual std: how much variability the trend + MSC account for
GRACE_dif = grace_common_std - grace_common_res_std
v4r5_dif = ecco_v4r5_common_std - ecco_v4r5_common_res_std
ctrl_dif = ecco_ctrl_common_std - ecco_ctrl_common_res_std

data_list = [GRACE_dif, v4r5_dif, ctrl_dif]

title_list = ['GRACE dif', 'v4r5 dif', 'ctrl dif']

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# Discrete diverging scale: 0.25 cm bins across [-3, 3] (24 bins, 25 edges)
boundaries = np.linspace(-3.0, 3.0, 25)
n_bins = len(boundaries) - 1   # one colour per bin (previously 25 colours for 24 bins)

# Sample n_bins evenly spaced colours from "coolwarm". plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
jet = plt.get_cmap("coolwarm", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


labels = ['(a)', '(b)', '(c)']

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 3), dpi=300, 
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridlines: this is a single row of panels, so every panel is on the bottom
    # row (the copied `i >= 3` rule from the 2x3 figures never labelled any panel).
    # Latitude labels only on the leftmost panel.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = True
    gl.left_labels = i == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=12)

    # Subplot labels (a)-(c)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared discrete colorbar with one labelled tick per bin edge
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.2f}' for b in boundaries])
cbar.set_label("Ocean Bottom Pressure Standard Deviation (cm)")
plt.suptitle('Difference between Raw and Residual Std')
plt.tight_layout(rect=[0, 0.1, 1, 1])
#plt.savefig("OBP_MSC_std_Fig_GRACEV2_new_color.png", dpi=500, bbox_inches='tight')

plt.show()



In [None]:
# Check whether any ctrl_dif values are negative, i.e. grid points where the raw
# std is smaller than the residual std: histogram of all ctrl_dif values.

ctrl_dif.plot.hist()

In [None]:


import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# Discrete colour scale: 0.0-4.0 cm in 0.5 cm bins. The colour count is derived
# from the bin edges, so the colormap and the norm cannot drift apart.
boundaries = np.arange(0, 4.0 + 0.5, 0.5)
n_bins = len(boundaries) - 1   # 8 bins

# Sample evenly from the "jet" colormap. `plt.cm.get_cmap` was deprecated in
# Matplotlib 3.7 and removed in 3.9; `plt.get_cmap(name, lut)` is the supported form.
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

# Create discrete colormap
cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))

# Define the datasets and titles (top row: residual std; bottom row: differences)
data_list = [
    grace_common_res_std,
    ecco_v4r5_common_res_std,
    ecco_ctrl_common_res_std,
    grace_min_v4r5_res_std,
    grace_min_ctrl_res_std,
    ctrl_min_v4r5_res_std
]

title_list = [
    "GRACE Residual",
    "ECCO v4r5 Residual",
    "ECCO ctrl Residual",
    "GRACE Residual - v4r5 Residual",
    "GRACE Residual - ctrl Residual",
    "ctrl Residual - v4r5 Residual"
]

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

n_rows, n_cols = 2, 3
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 4.5), dpi=300,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridlines: longitude labels on the bottom row, latitude labels on the left
    # column. `bottom_labels` / `left_labels` replace the deprecated
    # `xlabels_bottom` / `ylabels_left` aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i >= (n_rows - 1) * n_cols
    gl.left_labels = i % n_cols == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=12)

    # Subplot labels (a)-(f)
    ax.text(-0.1, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Colorbar (placed manually below the panels)
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("Ocean Bottom Pressure Standard Deviation (cm)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("OBP_residual_std_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')

plt.show()


In [None]:
# Diverging discrete colour scale: -0.4 to 0.4 cm in 0.1 cm steps (8 bins).
boundaries = np.arange(-0.4, 0.4 + 0.1, 0.1)
# One colour per bin. The previous hard-coded 10 colours for 8 bins made
# BoundaryNorm spread the colours unevenly over the bins, so the diverging scale
# was not symmetric about zero.
n_bins = len(boundaries) - 1

# `plt.cm.get_cmap` was removed in Matplotlib 3.9; use `plt.get_cmap(name, lut)`.
coolwarm = plt.get_cmap("coolwarm", n_bins)
cmap_disc = ListedColormap([coolwarm(i) for i in range(n_bins)])
norm_disc = BoundaryNorm(boundaries, ncolors=n_bins)

# Difference of the two residual-std difference maps (GRACE - v4r5 minus GRACE - ctrl).
data = grace_min_v4r5_res_std - grace_min_ctrl_res_std

fig = plt.figure(figsize=(11, 6), dpi=300)
ax = plt.subplot(1, 1, 1, projection=ccrs.PlateCarree())   # <- GeoAxes

pc = ax.pcolormesh(data.longitude,
                   data.latitude,
                   data,
                   transform=ccrs.PlateCarree(),
                   cmap=cmap_disc,
                   norm=norm_disc)
ax.set_global()

# Gridlines (current cartopy names; `xlabels_bottom` / `ylabels_left` are
# deprecated aliases)
gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
gl.bottom_labels = True
gl.left_labels = True

# Add land feature
ax.add_feature(cfeature.NaturalEarthFeature(
    'physical', 'land', '110m',
    linewidth=0.5, edgecolor='black', facecolor='darkgray'
))

ax.set_title("(GRACE - v4r5) Residual minus (GRACE - ctrl) Residual", fontsize=12)


# Colorbar (placed manually below the map)
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.1f}' for b in boundaries])
cbar.set_label("(cm)")

plt.savefig("Fig4d_minus_Fig4e_residual.png", dpi=500, bbox_inches='tight')

plt.show()
#plt.tight_layout(rect=[0, 0.1, 1, 1])


#dif.plot()

### Cost Plots

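Compute the cost function at each grid point, for both ECCO v4r5 and ctrl:\
$\text{cost} = \dfrac{\operatorname{var}_t(\text{Model} - \text{Obs})}{\operatorname{var}(\text{Data Error})}$\
where Obs is GRACE and var(Data Error) is the squared GRACE error.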

In [None]:
import xarray as xr
import numpy as np

# Cell centres of a global 0.5-degree grid:
# latitudes -89.75 .. 89.75, longitudes -179.75 .. 179.75.
grid_res = 0.5
lat_common = np.arange(-90.0 + grid_res / 2, 90.0, grid_res)
lon_common = np.arange(-180.0 + grid_res / 2, 180.0, grid_res)

# The regridder only needs the destination coordinates, so a Dataset holding
# just `lat` and `lon` describes the target grid.
common_grid = xr.Dataset(
    {"lat": (["lat"], lat_common), "lon": (["lon"], lon_common)}
)


In [None]:
import xesmf as xe


def _bilinear_to_common(source):
    """Build a bilinear xESMF regridder from `source`'s grid onto `common_grid`,
    with nearest source-to-destination extrapolation for unmapped points."""
    return xe.Regridder(
        source,
        common_grid,
        method="bilinear",
        extrap_method="nearest_s2d",
    )


# ECCO fields onto the common grid; `regrid_ecco` ends up as the ctrl regridder.
regrid_ecco = _bilinear_to_common(ecco_v4r5_common)
ecco_v4r5_common_rg = regrid_ecco(ecco_v4r5_common)

regrid_ecco = _bilinear_to_common(ecco_ctrl_common)
ecco_ctrl_common_rg = regrid_ecco(ecco_ctrl_common)

# GRACE onto the common grid; `regrid_grace` is kept for other GRACE fields.
regrid_grace = _bilinear_to_common(grace_common)
grace_common_rg = regrid_grace(grace_common)


In [None]:
# GRACE error on the common grid, squared to give the error variance used as the
# cost denominator (assumes grace_error is a standard deviation — confirm).
grace_error_rg = regrid_grace(grace_error)
GRACE_error_var = pow(grace_error_rg, 2)


In [None]:
# Temporal variance of the model-minus-GRACE misfit at every grid point
# (v4r5 first, then ctrl).
v4r5_min_grace_var, ctrl_min_grace_var = (
    (model_rg - grace_common_rg).var(dim="time")
    for model_rg in (ecco_v4r5_common_rg, ecco_ctrl_common_rg)
)

# Cost = misfit variance / GRACE error variance, per grid point.
cost_v4r5 = v4r5_min_grace_var / GRACE_error_var
cost_ctrl = ctrl_min_grace_var / GRACE_error_var


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm


# Dataset and metadata
data_list = [cost_ctrl, cost_v4r5]
title_list = ["Model Cost (ctrl)",
              "Model Cost (v4r5)"]
labels = ['(a)', '(b)']

# Discrete colour scale: cost 0 to 4 in steps of 0.4 (10 bins).
boundaries = np.arange(0, 4 + 0.4, 0.4)
# One colour per bin. The former hard-coded 40 colours for 10 bins were only
# subsampled by BoundaryNorm.
n_bins = len(boundaries) - 1

# Sample evenly from "jet". `plt.cm.get_cmap` was removed in Matplotlib 3.9;
# `plt.get_cmap(name, lut)` is the supported form.
jet = plt.get_cmap("jet", n_bins)
jet_colors = [jet(i) for i in range(n_bins)]

# Create discrete colormap
cmap_disc = ListedColormap(jet_colors)
norm_disc = BoundaryNorm(boundaries, ncolors=len(jet_colors))


# Create subplots (2 rows x 1 column, tight vertical spacing)
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(6, 6.5), dpi=300,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)},
                         gridspec_kw={'hspace': 0.01})  # Reduce from default ~0.2

axes = axes.flatten()

for i, (ax, data, title) in enumerate(zip(axes, data_list, title_list)):
    pc = ax.pcolormesh(data.lon, data.lat, data,
                       transform=ccrs.PlateCarree(), cmap=cmap_disc, norm=norm_disc)

    ax.set_global()
    # Gridlines: longitude labels on the bottom panel only, latitude labels on
    # both (single column). Current cartopy names replace the deprecated
    # `xlabels_bottom` / `ylabels_left` aliases.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01)
    gl.bottom_labels = i == 1
    gl.left_labels = True
    gl.right_labels = False
    gl.top_labels = False
    gl.xlabel_style = {'fontsize': 7}
    gl.ylabel_style = {'fontsize': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'))

    ax.set_title(title, fontsize=8, fontweight='bold')
    ax.text(-0.09, 1.01, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # Vertical colorbar to the right of each panel
    cbar = fig.colorbar(pc, ax=ax, orientation='vertical', pad=0.05, fraction=0.05, shrink=0.85)
    cbar.set_ticks(boundaries)
    cbar.ax.set_yticklabels([f'{b:.1f}' for b in boundaries], fontsize=6)


plt.tight_layout()
plt.savefig("Model_Cost_Maps_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


### Cost Histograms

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Flatten both cost maps and keep only finite values (drops NaN and +/-inf).
c1, c2 = (
    flat[np.isfinite(flat)]
    for flat in (np.asarray(cost).ravel() for cost in (cost_ctrl, cost_v4r5))
)

# 40 shared bin edges over the pooled 1st-99th percentile range, so both panels
# use identical bins and outliers do not stretch the axis.
lo, hi = np.percentile(np.concatenate([c1, c2]), [1, 99])
bins = np.linspace(lo, hi, 40)

fig, axes = plt.subplots(1, 2, figsize=(10,4), dpi=150, sharey=True)

panel_specs = [(c1, "Cost (CTRL)", "cost_ctrl"), (c2, "Cost (V4r5)", "cost_v4r5")]
for hist_ax, (hist_vals, hist_title, hist_xlabel) in zip(axes, panel_specs):
    hist_ax.hist(hist_vals, bins=bins, edgecolor='gray')
    hist_ax.set_title(hist_title)
    hist_ax.set_xlabel(hist_xlabel)
    hist_ax.grid(True, ls="--", lw=0.4, alpha=0.5)
axes[0].set_ylabel("Count")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Flatten both cost maps and keep only finite values (drops NaN and +/-inf).
c1, c2 = (
    flat[np.isfinite(flat)]
    for flat in (np.asarray(cost).ravel() for cost in (cost_ctrl, cost_v4r5))
)

# Robust range from the pooled data (1st-99th percentile), snapped outward to
# multiples of the bin width w = 0.05.
w = 0.05
lo, hi = np.percentile(np.concatenate([c1, c2]), [1, 99])
lo = np.floor(lo / w) * w
hi = np.ceil(hi / w) * w

# bins every 0.05
bins = np.arange(lo, hi + w, w)

fig, axes = plt.subplots(1, 2, figsize=(10,4), dpi=150, sharey=True)

panel_specs = [(c1, "Cost (CTRL)", "cost_ctrl"), (c2, "Cost (V4r5)", "cost_v4r5")]
for hist_ax, (hist_vals, hist_title, hist_xlabel) in zip(axes, panel_specs):
    hist_ax.hist(hist_vals, bins=bins, edgecolor='none')
    hist_ax.set_title(hist_title)
    hist_ax.set_xlabel(hist_xlabel)
    hist_ax.grid(True, ls="--", lw=0.4, alpha=0.5)
axes[0].set_ylabel("Count")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


# flatten + drop NaNs/Infs
c1 = np.asarray(cost_ctrl).ravel()
c2 = np.asarray(cost_v4r5).ravel()
c1 = c1[np.isfinite(c1)]
c2 = c2[np.isfinite(c2)]

# choose a robust range from the pooled data (1st-99th percentile)
lo = np.percentile(np.concatenate([c1, c2]), 1)
hi = np.percentile(np.concatenate([c1, c2]), 99)

# snap the range outward to multiples of the bin width w
w = 0.1
lo = np.floor(lo / w) * w
hi = np.ceil(hi / w) * w

# bins every w from lo to hi, built from an integer bin count:
# np.arange(lo, hi + w, w) can add an extra edge past `hi` through float round-off.
n_hist_bins = max(int(round((hi - lo) / w)), 1)
bins = lo + w * np.arange(n_hist_bins + 1)

fig, axes = plt.subplots(1, 2, figsize=(10,4), dpi=300, sharey=True)

axes[0].hist(c1, bins=bins, edgecolor='gray')
axes[0].set_title("Cost (CTRL)")
axes[0].set_xlabel("cost_ctrl"); axes[0].set_ylabel("Count"); axes[0].grid(True, ls="--", lw=0.4, alpha=0.5)

axes[1].hist(c2, bins=bins, edgecolor='gray')
axes[1].set_title("Cost (V4r5)")
axes[1].set_xlabel("cost_v4r5"); axes[1].grid(True, ls="--", lw=0.4, alpha=0.5)

# Major ticks every 0.5, on multiples of 0.5 inside [lo, hi]. Starting at `lo`
# put ticks at arbitrary offsets, and a tick beyond `hi` made set_xticks widen
# the axis.
tick_step = 0.5
tick_start = np.ceil(lo / tick_step) * tick_step
xticks = np.arange(tick_start, hi + tick_step / 2, tick_step)

for ax in axes:
    ax.tick_params(axis='both', which='major', labelsize=4)   # major ticks
    ax.tick_params(axis='both', which='minor', labelsize=1)   # minor ticks (if used)

    ax.set_xticks(xticks)                 # positions
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))  # 1 decimal

plt.tight_layout(); plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Flatten both cost maps and keep only finite values.
c1, c2 = (
    flat[np.isfinite(flat)]
    for flat in (np.asarray(cost).ravel() for cost in (cost_ctrl, cost_v4r5))
)

# Shared bins of width w over the pooled 1st-99th percentile range, with the
# range snapped outward to multiples of w.
w = 0.1
p_low, p_high = np.percentile(np.concatenate([c1, c2]), [1, 99])
lo = np.floor(p_low / w) * w
hi = np.ceil(p_high / w) * w
bins = np.arange(lo, hi + w, w)

fig, axes = plt.subplots(1, 2, figsize=(10,4), dpi=300, sharey=True)

def hist_with_peak(ax, data, bins, title, xlabel):
    """Draw a histogram of `data` on `ax` and annotate its most populated bin.

    Parameters
    ----------
    ax : matplotlib Axes (anything providing hist/annotate/set_title/set_xlabel/grid)
    data : 1-D array-like of values to histogram
    bins : bin specification passed straight to ``ax.hist``
    title, xlabel : str
        Axes title and x-axis label.

    Returns
    -------
    (peak_x, peak_y)
        Centre and count of the highest-frequency bin (the first one on ties).
    """
    # draw the histogram and keep the counts / edges it computed
    counts, edges, _ = ax.hist(data, bins=bins, edgecolor='none')
    centers = 0.5 * (edges[:-1] + edges[1:])
    i_max = int(np.argmax(counts))
    peak_x = centers[i_max]
    peak_y = counts[i_max]

    # annotate in axes coordinates (upper-right corner)
    ax.annotate(
        f"highest frequency bin at {peak_x:.2f}",
        xy=(0.95, 0.9),
        xycoords="axes fraction",
        ha="right", va="bottom",
    )

    # style
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.grid(True, ls="--", lw=0.4, alpha=0.5)
    return peak_x, peak_y

# One annotated histogram per model, sharing `bins`; only the left panel gets a y label.
for hist_ax, hist_vals, hist_title, hist_xlabel in zip(
        axes, (c1, c2), ("Cost (CTRL)", "Cost (V4r5)"), ("cost_ctrl", "cost_v4r5")):
    hist_with_peak(hist_ax, hist_vals, bins, hist_title, hist_xlabel)
axes[0].set_ylabel("Count")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# flatten + drop NaNs/Infs
c1 = np.asarray(cost_ctrl).ravel()
c2 = np.asarray(cost_v4r5).ravel()
c1 = c1[np.isfinite(c1)]
c2 = c2[np.isfinite(c2)]

# choose a robust range from the pooled data (1st-99th percentile)
lo = np.percentile(np.concatenate([c1, c2]), 1)
hi = np.percentile(np.concatenate([c1, c2]), 99)

# snap the range outward to multiples of the bin width w
w = 0.2
lo = np.floor(lo / w) * w
hi = np.ceil(hi / w) * w

# bins every w from lo to hi, built from an integer bin count:
# np.arange(lo, hi + w, w) can add an extra edge past `hi` through float round-off.
n_hist_bins = max(int(round((hi - lo) / w)), 1)
bins = lo + w * np.arange(n_hist_bins + 1)

fig, axes = plt.subplots(1, 2, figsize=(10,4), dpi=150, sharey=True)

axes[0].hist(c1, bins=bins, edgecolor='none')
axes[0].set_title("Cost (CTRL)")
axes[0].set_xlabel("cost_ctrl"); axes[0].set_ylabel("Count"); axes[0].grid(True, ls="--", lw=0.4, alpha=0.5)

axes[1].hist(c2, bins=bins, edgecolor='none')
axes[1].set_title("Cost (V4r5)")
axes[1].set_xlabel("cost_v4r5"); axes[1].grid(True, ls="--", lw=0.4, alpha=0.5)

plt.tight_layout(); plt.show()


### EOF Analysis on ECCO Adjustments and OBP Residuals (GRACE-V4r5)

EOF analyses performed:
 - GRACE (Raw)
 - V4R5 (Raw)
 - ctrl (Raw)
 - GRACE (Detrended)
 - V4R5 (Detrended)
 - ctrl (Detrended)
 - GRACE (Residual)
 - V4R5 (Residual)
 - ctrl (Residual)
 - GRACE - V4r5
 - ctrl - V4r5
 - GRACE - ctrl

In [None]:
from eofs.xarray import Eof

# EOF analysis of the raw GRACE field. The input is prepared the same way as in
# the batch EOF loop:
#  - time is the leading dimension (required by the eofs solver);
#  - latitude/longitude come last when present;
#  - only dimension coordinates are kept.
grace_eof_input = (
    GRACE_common
    .transpose('time', ..., 'latitude', 'longitude', missing_dims='ignore')
    .reset_coords(drop=True)
)
# NaNs (e.g. over land) are set to 0 so every grid point enters the analysis.
# NOTE(review): no latitude weighting (Eof(..., weights=...)) is applied — confirm intended.
solver = Eof(grace_eof_input.fillna(0))

eofs = solver.eofsAsCorrelation(neofs=5)  # Leading EOF patterns (as correlation maps)
pcs = solver.pcs(npcs=5, pcscaling=1)     # Corresponding PCs (pcscaling=1: unit variance)
variance_fraction = solver.varianceFraction(neigs=40)  # Variance explained

# Rescale each PC to unit standard deviation over time.
pcs_normalized = pcs / pcs.std(dim='time')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap, BoundaryNorm

def plot_eofs_with_pcs(
    eofs,                       # DataArray with a 'mode' dim OR sequence of 2D DataArrays (lat x lon)
    pcs,                        # xarray DataArray indexed as pcs[:, i] (presumably dims ('time','mode'))
    title="title",
    n_modes=5,
    variance_fraction=None,     # list/array of length >= n_modes, optional
    cmap="coolwarm",
    boundaries=None,            # 1D array of bin edges for discrete colormap (shared by all maps)
    labels=None,                # e.g., ['(a)','(b)', ...]; auto-made if None
    central_longitude=0,
    figsize=(14, 16),
    dpi=300,
    save_path=None
):
    """
    Plot EOF maps (left column) and PC time series (right column), one row per mode.

    - One horizontal colorbar, shared by all maps, is drawn under the map column.
    - Maps use a discrete colormap: `cmap` sampled with one colour per interval
      of `boundaries`.

    Parameters
    ----------
    eofs : xarray.DataArray with a ``mode`` dimension, or a sequence of 2-D DataArrays
        EOF patterns. Each map needs a lon/longitude/x and a lat/latitude/y coordinate.
    pcs : xarray.DataArray
        PC time series; column ``pcs[:, i]`` is drawn with ``DataArray.plot``.
    title : str
        Figure suptitle.
    n_modes : int
        Number of leading modes (rows) to plot.
    variance_fraction : sequence of float, optional
        Variance fraction per mode, shown as a percentage in each map title.
    cmap : str or Colormap
        Colormap to discretise.
    boundaries : 1-D array, optional
        Colour bin edges for all maps; default -1 to 1 in steps of 0.25.
    labels : sequence of str, optional
        Panel labels; default '(a)', '(b)', ...
    central_longitude : float
        Central longitude of the PlateCarree map projection.
    figsize : tuple
        Figure size in inches.
    dpi : int
        Figure resolution, also used when saving.
    save_path : str, optional
        If given, the figure is saved there; a missing parent directory is created.
    """

    # --- Grab first n_modes ---
    # Accept either a list of DataArrays or a DataArray with a 'mode' dim
    if hasattr(eofs, "dims") and "mode" in getattr(eofs, "dims", ()):
        map_list = [eofs.sel(mode=i).squeeze() for i in range(n_modes)]
    else:
        map_list = list(eofs[:n_modes])

    if labels is None:
        labels = [f"({chr(97+i)})" for i in range(n_modes)]

    # Index of the bottom row (was hard-coded as 4, i.e. only right for n_modes=5).
    last_row = n_modes - 1

    # --- Build discrete colormap ---
    if boundaries is None:
        boundaries = np.arange(-1, 1 + 0.25, 0.25)  # default edges
    n_bins = len(boundaries) - 1
    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9.
    base = plt.get_cmap(cmap, n_bins)
    cmap_disc = ListedColormap([base(i) for i in range(n_bins)])
    norm_disc = BoundaryNorm(boundaries, ncolors=n_bins)

    # --- Helper to find coord names robustly ---
    def _get_xy(da):
        """Return (latitude, longitude) coordinates of `da` under common names."""
        for lon_name in ("lon", "longitude", "x"):
            if lon_name in da.coords:
                LON = da[lon_name]
                break
        else:
            raise ValueError("Longitude coordinate not found.")
        for lat_name in ("lat", "latitude", "y"):
            if lat_name in da.coords:
                LAT = da[lat_name]
                break
        else:
            raise ValueError("Latitude coordinate not found.")
        return LAT, LON

    # --- Figure & layout ---
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs  = fig.add_gridspec(nrows=n_modes, ncols=2, width_ratios=[1, 1], hspace=0.2, wspace=0.05)

    proj  = ccrs.PlateCarree(central_longitude=central_longitude)
    pcart = ccrs.PlateCarree()

    map_axes, ts_axes = [], []
    qm_last = None

    for i in range(n_modes):
        ax_map = fig.add_subplot(gs[i, 0], projection=proj)
        ax_ts  = fig.add_subplot(gs[i, 1])

        da = map_list[i]
        LAT, LON = _get_xy(da)

        # Map
        qm = ax_map.pcolormesh(
            LON, LAT, da,
            transform=pcart, cmap=cmap_disc, norm=norm_disc, rasterized=True
        )
        qm_last = qm
        map_axes.append(ax_map)

        ax_map.set_global()
        ax_map.add_feature(
            cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                         linewidth=0.4, edgecolor='black', facecolor='darkgray')
        )
        gl = ax_map.gridlines(color='gray', linestyle='--', linewidth=0.2, draw_labels=False)

        # Longitude labels on the bottom map only; `bottom_labels` replaces the
        # deprecated `xlabels_bottom` alias.
        if i == last_row:
            gl.bottom_labels = True

        # Map title
        if variance_fraction is not None:
            vf_pct = float(variance_fraction[i]) * 100
            ax_map.set_title(f"Mode {i+1} ({vf_pct:.2f}%)", fontsize=10)
        else:
            ax_map.set_title(f"Mode {i+1}", fontsize=10)

        ax_map.text(-0.06, 1.02, labels[i], transform=ax_map.transAxes,
                    fontsize=9, fontweight='bold', va='bottom', ha='left',
                    bbox=dict(facecolor='white', edgecolor='none', pad=1.0, alpha=0.6))

        # PC time series (x tick labels only on the bottom row)
        ts = pcs[:, i]
        ts.plot(ax=ax_ts, lw=0.8)
        ax_ts.axhline(0, color='0.5', lw=0.6)
        ax_ts.set_title(f"PC {i+1}", fontsize=10)
        if i < last_row:
            ax_ts.set_xlabel("")
            ax_ts.tick_params(labelbottom=False)
        ax_ts.grid(True, ls='--', lw=0.3, alpha=0.5)
        ts_axes.append(ax_ts)

    # Shared horizontal colorbar under the left column
    pos_bottom_map = map_axes[-1].get_position()
    cax = fig.add_axes([pos_bottom_map.x0, pos_bottom_map.y0 - 0.055,
                        pos_bottom_map.width, 0.018])
    cb  = fig.colorbar(qm_last, cax=cax, orientation='horizontal', ticks=boundaries)
    cb.set_label("EOF amplitude (normalized units)", fontsize=9)
    cb.ax.tick_params(labelsize=8)

    # Titles & spacing
    fig.suptitle(title, fontsize=16, fontweight='bold', y=0.92)

    # Adjust layout, leaving space for the title.
    # NOTE(review): the colorbar axes is positioned from the bottom map *before*
    # tight_layout runs and is not managed by it, so the two may end up misaligned.
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Saved plot to {save_path}")
    plt.show()
    # Release the figure; this function is called in loops over many fields.
    plt.close(fig)


In [None]:
# Five leading raw-GRACE EOF maps with their normalised PCs.
plot_eofs_with_pcs(
    eofs,
    pcs_normalized,
    title="GRACE Raw EOFs",
    variance_fraction=variance_fraction,
    n_modes=5,
)


In [None]:
%%time

# Write a for loop to compute all interested EOFs.
ds = [GRACE_common, ECCO_pb_v4r5_common, ECCO_pb_ctrl_common,
      grace_dt_sr, v4r5_dt_sr, ctrl_dt_sr,
      grace_min_v4r5, grace_min_ctrl, ctrl_min_v4r5]

titles = ["GRACE Raw EOFs", "ECCO v4r5 EOFs", "ECCO ctrl EOFs",
          "GRACE Residual EOFs", "ECCO v4r5 Residual EOFs", "ECCO ctrl Residual EOFs",
          "GRACE - v4r5 EOFs", "GRACE - ctrl EOFs", "ctrl - v4r5 EOFs"]

solvers = []
eofs =[]
pcs = []
variance_fractions = []
pcs_norm = []

for i in range(len(ds)):

    temp = ds[i]
    if temp.dims[0] != 'time':
        temp = temp.transpose('time', 'latitude', 'longitude')
    
    temp = temp.reset_coords(drop=True)  # keep only dimension coords
    solver = Eof(temp.fillna(0))
    #solvers.append(solver)

    eof = solver.eofsAsCorrelation(neofs=5)  # Leading EOF patterns
    pc = solver.pcs(npcs=5, pcscaling=1)     # Corresponding PCs
    variance_fraction = solver.varianceFraction(neigs=40)  # Variance explained

    pc_norm = pc / pc.std(dim='time')

    eofs.append(eof)
    pcs.append(pc)
    variance_fractions.append(variance_fraction)
    pcs_norm.append(pc_norm)

    


In [None]:
# One figure per field, saved as ./eof_figs/<title>.png
for i in range(len(ds)):
    name = titles[i]
    out_png = f"./eof_figs/{name}.png"
    plot_eofs_with_pcs(eofs[i], pcs_norm[i], title=name, n_modes=5,
                       variance_fraction=variance_fractions[i], save_path=out_png)
    




### ECCO Full Period EOF
EOFs of "ctrl - v4r5" over the full ECCO period, in three versions: raw fields, detrended, and detrended with the mean seasonal cycle (MSC) also removed.

In [None]:
# Full-period regridded ECCO fields.
v4r5 = ECCO_pb_v4r5_pp_rg
ctrl = ECCO_pb_ctrl_pp_rg

# Processing chain: raw -> detrended -> detrended with seasonal cycle removed.
v4r5_dt, ctrl_dt = detrend_dim(v4r5), detrend_dim(ctrl)
v4r5_sr, ctrl_sr = remove_seasonal_cycle(v4r5_dt), remove_seasonal_cycle(ctrl_dt)

# ctrl - v4r5 at each processing stage.
res, res_dt, res_sr = ctrl - v4r5, ctrl_dt - v4r5_dt, ctrl_sr - v4r5_sr

ds = [res, res_dt, res_sr]


solvers = []   # left empty: the Eof solver objects are not kept
eofs = []
pcs = []
variance_fractions = []
pcs_norm = []

for temp in ds:
    # The eofs solver needs time as the leading dimension. Put time first and
    # latitude/longitude last (when present). The old
    # `if temp.dims[0] != 'time'` check failed for any other spatial dim names.
    temp = temp.transpose('time', ..., 'latitude', 'longitude', missing_dims='ignore')
    temp = temp.reset_coords(drop=True)  # keep only dimension coords
    # NaNs (e.g. land) are set to 0 so every grid point enters the analysis.
    # NOTE(review): no latitude weighting (Eof(..., weights=...)) — confirm intended.
    solver = Eof(temp.fillna(0))

    eof = solver.eofsAsCorrelation(neofs=5)  # Leading EOF patterns
    pc = solver.pcs(npcs=5, pcscaling=1)     # Corresponding PCs
    variance_fraction = solver.varianceFraction(neigs=40)  # Variance explained

    pc_norm = pc / pc.std(dim='time')        # unit standard deviation over time

    eofs.append(eof)
    pcs.append(pc)
    variance_fractions.append(variance_fraction)
    pcs_norm.append(pc_norm)



In [None]:
titles = ["ctrl - v4r5 Raw",
          "ctrl - v4r5 Detrended",
          "ctrl - v4r5 Residual"]

# One figure per processing stage, saved as ./eof_figs/ECCO_full_period_<title>.png
for i in range(len(ds)):
    name = titles[i]
    out_png = f"./eof_figs/ECCO_full_period_{name}.png"
    plot_eofs_with_pcs(eofs[i], pcs_norm[i], title=name, n_modes=5,
                       variance_fraction=variance_fractions[i], save_path=out_png)

In [None]:
# Re-bind the full-period regridded ECCO fields.
v4r5, ctrl = ECCO_pb_v4r5_pp_rg, ECCO_pb_ctrl_pp_rg

In [None]:
# Display the full-period v4r5 field (ECCO_pb_v4r5_pp_rg) for inspection.
v4r5

In [None]:
import xarray as xr
import numpy as np

# Remove from each full-record field the linear trend estimated over the GRACE
# subperiod only. Both models are processed here, so a single run defines
# `v4r5_reg_detr` and `ctrl_reg_detr`. Previously ctrl needed the cell to be
# edited and re-run by hand.
t0 = "2002-04-01"
t1 = "2019-12-01"


def detrend_with_subperiod_trend(da, t0, t1, deg=1):
    """Subtract from all of `da` a polynomial trend fitted over time in [t0, t1].

    Parameters
    ----------
    da : xarray.DataArray with a 'time' dimension
    t0, t1 : str or datetime-like
        Bounds of the label-based time slice used for the fit.
    deg : int
        Polynomial degree (1 = linear).

    Returns
    -------
    (detrended, coefs)
        `da` minus the fitted trend evaluated at every time of `da`, and the
        Dataset returned by `DataArray.polyfit`.
    """
    sub = da.sel(time=slice(t0, t1))
    # DataArray.polyfit (not xr.polyfit) returns a Dataset whose coefficient
    # variable is "polyfit_coefficients" (dims: degree + the non-time dims of da).
    # Dataset.polyfit is the variant that prefixes the data-variable name.
    coefs = sub.polyfit(dim="time", deg=deg)
    # Evaluate the subperiod fit at *all* times of the full record.
    trend_full = xr.polyval(da["time"], coefs["polyfit_coefficients"])
    return da - trend_full, coefs


v4r5_reg_detr, coefs = detrend_with_subperiod_trend(v4r5, t0, t1)
ctrl_reg_detr, coefs_ctrl = detrend_with_subperiod_trend(ctrl, t0, t1)


In [None]:
# Print both fields detrended with the trend fitted over the t0..t1 subperiod.
print(v4r5_reg_detr, ctrl_reg_detr)

In [None]:
# ctrl - v4r5 difference of the subperiod-detrended fields; `ds` is the list of
# fields fed to the EOF loop.
dif_reg_dt = ctrl_reg_detr - v4r5_reg_detr
ds = [dif_reg_dt]

In [None]:
solvers = []   # left empty: the Eof solver objects are not kept
eofs = []
pcs = []
variance_fractions = []
pcs_norm = []

for temp in ds:
    # The eofs solver needs time as the leading dimension. Put time first and
    # latitude/longitude last (when present). The old
    # `if temp.dims[0] != 'time'` check failed for any other spatial dim names.
    temp = temp.transpose('time', ..., 'latitude', 'longitude', missing_dims='ignore')
    temp = temp.reset_coords(drop=True)  # keep only dimension coords
    # NaNs (e.g. land) are set to 0 so every grid point enters the analysis.
    # NOTE(review): no latitude weighting (Eof(..., weights=...)) — confirm intended.
    solver = Eof(temp.fillna(0))

    eof = solver.eofsAsCorrelation(neofs=5)  # Leading EOF patterns
    pc = solver.pcs(npcs=5, pcscaling=1)     # Corresponding PCs
    variance_fraction = solver.varianceFraction(neigs=40)  # Variance explained

    pc_norm = pc / pc.std(dim='time')        # unit standard deviation over time

    eofs.append(eof)
    pcs.append(pc)
    variance_fractions.append(variance_fraction)
    pcs_norm.append(pc_norm)


In [None]:
titles = [
          "ctrl - v4r5 Detrended (Grace Period)",
          ]

# One figure per field, saved as ./eof_figs/ECCO_full_period_<title>.png
for i in range(len(ds)):
    name = titles[i]
    out_png = f"./eof_figs/ECCO_full_period_{name}.png"
    plot_eofs_with_pcs(eofs[i], pcs_norm[i], title=name, n_modes=5,
                       variance_fraction=variance_fractions[i], save_path=out_png)

In [None]:
# Display the polyfit coefficients of the subperiod trend fit.
coefs

### Barystatic SL Figure

In [None]:
r5_dir = '/glade/work/mengnanz/V4r5_ctrl/OBP_mon/'
data_list = sorted(os.listdir(r5_dir))

# Monthly decimal-year axis, Jan 1992 - Dec 2019.
r5_month = np.arange(1992, 2020, 1/12)

import gc
# Per-tile 2-D mask: 1 where the sum of maskC over its first axis (presumably
# depth) is positive, else 0. Vectorised form of the former per-tile loop.
mask = (np.sum(eccor5_grid.maskC, axis=0) > 0).astype(int)
gc.collect()


# Read in ECCO ctrl OBP, scaled by 100 (stored with units "cm" below).
pb_r5_mon = np.full((len(r5_month), 13, 90, 90), np.nan)
for i in tqdm(np.arange(len(data_list))):
    data = ecco.read_llc_to_tiles(r5_dir, data_list[i], less_output=True, nk=-1) * 100
    data = data[1, :, :, :]
    pb_r5_mon[i, :, :, :] = data

# Exact zeros are treated as land / missing.
pb_r5_mon[pb_r5_mon == 0] = np.nan

# Area weights: cell area rA over the total masked area (computed once; it was
# duplicated). Summing pb * weight over non-NaN cells gives the area-weighted
# ocean mean of pb, assuming the NaN cells coincide with mask == 0.
total_ocn_area = np.nansum(eccor5_grid.rA * mask)
weight = eccor5_grid.rA / total_ocn_area
weight_expanded = np.expand_dims(weight, axis=0)
weight = np.tile(weight_expanded, (pb_r5_mon.shape[0], 1, 1, 1))
glob_mean = np.nansum(np.nansum(np.nansum(pb_r5_mon * weight, axis=-1), axis=-1), axis=-1)
glob_mean_eccor5 = np.copy(glob_mean)
glob_mean = np.tile(glob_mean[:, np.newaxis, np.newaxis, np.newaxis], (1, 13, 90, 90))

# Wrap pb_r5_mon in an xarray Dataset on the llc90 tile grid.
i = eccor5_grid['i']  # Shape: (90,)
j = eccor5_grid['j']  # Shape: (90,)
tile = eccor5_grid['tile']  # Shape: (13,)
time = (["time"], r5_month)

ds = xr.Dataset(
    data_vars=dict(
        pb=(["time", "tile", "j", "i"], pb_r5_mon, {"units": "cm"}),
    ),
    coords=dict(
        time=time,
        tile=tile,
        j=j,
        i=i
    ),
    attrs=dict(description="ECCO pb, cm", units='cm'),
)

# Save the dataset to a NetCDF file
#ds.to_netcdf(os.path.join(output_dir, 'pb_ECCO_ctrl.nc'))

# Month-start datetime axis matching r5_month (MS = Month Start).
new_time = pd.date_range(start="1992-01-01", end="2019-12-31", freq="MS")
ds = ds.assign_coords(time=new_time)

# Area-weighted global-mean OBP. The former ds.pb.mean(dim=['tile', 'j', 'i'])
# gave every llc90 cell equal weight regardless of its area rA, even though the
# area-weighted mean had already been computed above. Months with no valid data
# stay NaN; nansum alone would turn them into 0.
has_data = np.isfinite(pb_r5_mon).any(axis=(1, 2, 3))
ecco_ctrl_ts = xr.DataArray(
    np.where(has_data, glob_mean_eccor5, np.nan),
    coords={"time": ds["time"]},
    dims="time",
    name="pb",
)

ecco_ctrl_ts = ecco_ctrl_ts.sel(time=slice("2002-04-17", "2019-12-01"))
temp = ecco_ctrl_ts.mean(dim="time")
ecco_ctrl_ts_anom = ecco_ctrl_ts - temp

In [None]:
# ECCO v4r5 global-mean barystatic sea level:
#  - scale by 100 (presumably m -> cm);
#  - put on the monthly axis `new_time`;
#  - keep the GRACE period and remove its mean.
v4r5_raw = (
    xr.open_dataset("/glade/work/mengnanz/V4r5/GMSL.nc") * 100
).assign_coords(time=new_time).sel(time=slice("2002-04-17", "2019-12-01"))
temp = v4r5_raw.mean(dim="time")
v4r5_raw_anom = v4r5_raw - temp

In [None]:
# Inspect the v4r5 GMSL dataset after scaling (x100) and GRACE-period selection.
v4r5_raw

In [None]:
import numpy as np
import pandas as pd

# GRACE global ocean-mass time series; its time axis is in decimal years and is
# converted to datetimes further down.
grace_ts = xr.open_dataset("ocean_mass_GRACE_200204_202502.nc")

def decimal_years_to_actual_date_midday(decimal_years):
    """Convert decimal years (e.g. 2002.29) to datetimes at 12:00 on the nearest day.

    The fractional part is multiplied by the length of that calendar year (365 or
    366 days) and rounded with numpy's rounding (halves go to the even day). The
    result is used as a day offset from 1 January, and 12 hours are added.

    Returns a pandas DatetimeIndex cast to second precision ("datetime64[s]").
    """
    whole_years = np.floor(decimal_years).astype(int)
    fractions = decimal_years - whole_years

    # Year length in days = gap between consecutive 1 Januaries (handles leap years).
    year_lengths = np.array([
        (pd.Timestamp(year=yr + 1, month=1, day=1) - pd.Timestamp(year=yr, month=1, day=1)).days
        for yr in whole_years
    ])

    offsets = np.round(fractions * year_lengths).astype(int)

    stamps = []
    for yr, n_days in zip(whole_years, offsets):
        stamps.append(pd.Timestamp(year=yr, month=1, day=1) + pd.Timedelta(days=n_days, hours=12))

    # Drop sub-second precision
    return pd.to_datetime(stamps).astype("datetime64[s]")

# Apply to the GRACE dataset, in order:
#  - decimal years -> midday datetimes;
#  - divide by 10 (presumably mm -> cm);
#  - keep the same GRACE period as the ECCO series;
#  - remove the period mean.
grace_ts = (
    grace_ts.assign_coords(time=decimal_years_to_actual_date_midday(grace_ts.time.values))
    / 10
).sel(time=slice("2002-04-17", "2019-12-01"))
avg = grace_ts.mean()
grace_ts_anom = grace_ts - avg

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

# Series to draw as (x, y, legend label, line style). Each is an anomaly (cm)
# relative to its own mean over the GRACE period.
ts_specs = [
    (v4r5_raw_anom.time,
     v4r5_raw_anom.global_mean_barystatic_sea_level_anomaly,
     "ECCO v4r5",
     dict(color="#ff7f00", lw=1.6, ls="-")),
    (ecco_ctrl_ts_anom.time,
     ecco_ctrl_ts_anom,
     "ECCO Control",
     dict(color="#33a02c", lw=1.6, ls="-")),
    (grace_ts_anom.time,
     grace_ts_anom.ocean_mass,
     "GRACE",
     dict(color="#1f78b4", lw=0, marker="o", ms=2.8, mew=0)),  # markers only
]

# Compact single-panel figure
fig, ax = plt.subplots(figsize=(6, 2.8), dpi=500)
for ts_x, ts_y, ts_label, ts_style in ts_specs:
    ax.plot(ts_x, ts_y, label=ts_label, **ts_style)

ax.set_ylabel("Ocean Bottom Pressure Spatial Avg. (cm)", fontsize=9)
ax.set_xlim(pd.Timestamp("2002-04-01"), pd.Timestamp("2019-12-01"))
ax.set_ylim(-17, 17)   # fixed range; adjust if needed

# Year ticks every 2 years, plus a light dashed grid
ax.xaxis.set_major_locator(mdates.YearLocator(2))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.tick_params(axis="both", labelsize=8)
ax.grid(True, ls="--", lw=0.4, alpha=0.6)

# Compact legend, top-right
leg = ax.legend(loc="upper right", frameon=True, fontsize=8,
                handlelength=2.2, borderaxespad=0.6)
leg.get_frame().set_alpha(0.9)

plt.tight_layout()
plt.savefig("OBP_TS_only_Fig.png", dpi=500, bbox_inches='tight')
plt.show()


### RMS Variant Plots

In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

def detrend_per_pixel(da: xr.DataArray, time_dim: str = "time", deg: int = 1,
                      use_time: bool = False) -> xr.DataArray:
    """
    Remove a polynomial trend (default linear) at each pixel along `time_dim`.

    By default the trend is fitted against the *sample position* 0, 1, ..., n-1
    rather than the actual time values. For an evenly spaced series this matches
    fitting against time. For a subset with gaps (e.g. only the GRACE months),
    the gaps are ignored and the fitted trend is distorted. Pass
    ``use_time=True`` to fit against elapsed time instead.

    Parameters
    ----------
    da : xr.DataArray
        Input with a `time_dim` dimension. A separate polynomial is fitted for
        every combination of the other dimensions (e.g. each lat/lon pixel).
    time_dim : str
        Name of the time dimension (default "time").
    deg : int
        Polynomial degree for the trend (1 = linear).
    use_time : bool
        If False (default), fit against the sample index. If True, fit against
        the `time_dim` coordinate: elapsed days since the first sample for
        datetime64 coordinates, otherwise the raw values cast to float.

    Returns
    -------
    xr.DataArray
        Same coordinates as `da`, with the per-pixel trend removed. Attributes
        are copied from `da` and a "detrended" attribute is added.

    Raises
    ------
    ValueError
        If `time_dim` is not a dimension of `da`.
    """
    if time_dim not in da.dims:
        raise ValueError(f"`{time_dim}` not found in dims: {da.dims}")

    if use_time:
        tvals = np.asarray(da[time_dim].values)
        if np.issubdtype(tvals.dtype, np.datetime64):
            # Elapsed days since the first sample keeps the fit well conditioned.
            x = (tvals - tvals[0]) / np.timedelta64(1, "D")
        else:
            x = tvals
        x = np.asarray(x, dtype="float64")
    else:
        # Sample position 0..n-1 (default).
        x = np.arange(da.sizes[time_dim], dtype="float64")

    # polyfit uses the dimension's coordinate values as the abscissa, so swap in
    # the numeric axis temporarily.
    da_num = da.assign_coords({time_dim: x})

    # Fit along time for every pixel; skipna=True drops missing samples per pixel.
    pf = da_num.polyfit(dim=time_dim, deg=deg, skipna=True)  # Dataset with "polyfit_coefficients"

    # Evaluate the fitted trend on the same numeric axis.
    trend = xr.polyval(da_num[time_dim], pf.polyfit_coefficients)

    # Restore the original time coordinate so the subtraction aligns with `da`.
    trend = trend.assign_coords({time_dim: da[time_dim]})

    # Arithmetic drops attrs, so copy them back explicitly.
    detrended = da - trend
    detrended.attrs.update(da.attrs)
    detrended.attrs["detrended"] = f"per-pixel polynomial (deg={deg}) along {time_dim}"
    return detrended

def broadcast_weights(da):
    """
    Return cos(latitude) weights broadcast to the non-time grid of `da`.

    The weights are built from ``da.latitude`` and broadcast against a single
    time slice (``da.isel(time=0)``). They are NOT normalized and NOT masked to
    ocean/valid points. Callers must apply any mask and normalization
    themselves, e.g. multiply by a finite-data mask and divide by the masked sum.

    Parameters
    ----------
    da : xr.DataArray
        Field with a `latitude` dimension and a `time` dimension (presumably
        also `longitude` — confirm for new callers).

    Returns
    -------
    xr.DataArray
        Unnormalized cos(lat) weights with the dimensions of ``da.isel(time=0)``.
    """
    lat = da.latitude
    wlat = xr.DataArray(np.cos(np.deg2rad(lat)), coords={"latitude": lat}, dims=("latitude",))
    return wlat.broadcast_like(da.isel(time=0))


import numpy as np
import xarray as xr

def band_spatial_rms_anom(
    da: xr.DataArray,
    lat_min: float = -25,
    lat_max: float = 25,
    lat_name: str = "latitude",
    lon_name: str = "longitude",
    time_name: str = "time",
):
    """
    Spatial RMS about the band mean, per time step, inside a latitude band.

    At each time step the field is restricted to ``lat_min <= lat <= lat_max``.
    Its cos(lat)-weighted spatial mean over finite points is removed, and the
    cos(lat)-weighted root-mean-square of what remains is returned. Non-finite
    points get no weight at that time step.

    Parameters
    ----------
    da : xr.DataArray
        Field with `time_name`, `lat_name` and `lon_name` dimensions.
    lat_min, lat_max : float
        Band limits in degrees (inclusive). Works for ascending or descending
        latitude coordinates.
    lat_name, lon_name, time_name : str
        Dimension names.

    Returns
    -------
    xr.DataArray
        RMS along `time_name`. If the band contains no latitude rows, an all-NaN
        series with the same time coordinate is returned instead.
    """
    band_dims = (lat_name, lon_name)
    lats = da[lat_name]

    # .sel with a slice needs the bounds in the coordinate's own order.
    ascending = lats[0] < lats[-1]
    bounds = slice(lat_min, lat_max) if ascending else slice(lat_max, lat_min)
    sub = da.sel({lat_name: bounds})

    # Empty band -> NaN series on the original time axis.
    if sub.sizes.get(lat_name, 0) == 0:
        template = da.isel({time_name: slice(None), lat_name: 0, lon_name: 0})
        return xr.full_like(template, np.nan).drop_vars([lat_name, lon_name])

    # cos(lat) weights on the 2-D grid, NaN wherever the data are missing.
    coslat = xr.DataArray(
        np.cos(np.deg2rad(sub[lat_name])),
        dims=(lat_name,),
        coords={lat_name: sub[lat_name]},
    )
    weights = (coslat.broadcast_like(sub.isel({time_name: 0}, drop=True))
               * xr.where(np.isfinite(sub), 1.0, np.nan))

    # Weighted band mean, then weighted RMS of deviations from it.
    total_w = weights.sum(dim=band_dims, skipna=True)
    band_mean = (sub * weights).sum(dim=band_dims, skipna=True) / total_w
    sq_dev = (sub - band_mean) ** 2
    return np.sqrt((sq_dev * weights).sum(dim=band_dims, skipna=True) / total_w)


import os
import numpy as np
import matplotlib.pyplot as plt

def plot_rms_with_grace_shading(
    rms_v4r5,
    rms_ctrl,
    rms_diff,
    grace_times,
    title="RMS of ECCO v4r5, ctrl, and (v4r5−ctrl), With Trend (Global)",
    ylabel="RMS, cm",
    xlabel="Months",
    shade_label="GRACE Availability",
    shade_halfwidth="15D",     # '15D', np.timedelta64, datetime/pandas timedelta, or (numpy) integer days
    colors=None,               # {"v4r5":"#1f77b4","ctrl":"#ff7f0e","diff":"#2ca02c"}
    alpha=0.10,                # shading opacity
    ylim_bottom=0.0,
    ax=None,
    dpi=500,
    figsize=(12, 6),
    savepath=None,             # optional output path; parent directories are created
    save_kwargs=None           # dict passed to fig.savefig
):
    """
    Plot three RMS time series and shade the GRACE-available months.

    Parameters
    ----------
    rms_v4r5, rms_ctrl, rms_diff : xr.DataArray or pd.Series
        Series with a "time" coordinate (DataArray) or a datetime index (Series).
    grace_times : array-like of datetime64
        GRACE sample times. Each one inside the plotted time range gets a shaded
        span of ``t ± shade_halfwidth``.
    title, ylabel, xlabel, shade_label : str
        Text for the axes and the (single) legend entry of the shading.
    shade_halfwidth : str, int, float, numpy number, np.timedelta64 or timedelta
        Half-width of each shaded span. Strings must be whole days such as
        "15D". Numbers are days, truncated to an integer.
    colors : dict, optional
        Line colors keyed "v4r5", "ctrl", "diff".
    alpha : float
        Shading opacity.
    ylim_bottom : float or None
        Lower y limit; None leaves it automatic.
    ax : matplotlib Axes, optional
        Draw into this axes. If None, a new figure is created, laid out and shown.
    dpi, figsize :
        Used when creating the figure (dpi is also the default save dpi).
    savepath : str, optional
        If given, the figure is saved there.
    save_kwargs : dict, optional
        Extra keyword arguments for ``fig.savefig``.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    TypeError
        If a series has neither a "time" coordinate nor an index.
    ValueError
        If `shade_halfwidth` has an unsupported form.
    """
    import datetime

    def _get_time_and_vals(obj):
        if hasattr(obj, "coords") and "time" in getattr(obj, "coords", {}):
            return obj["time"].values, np.asarray(obj)
        if hasattr(obj, "index"):
            return obj.index.values, np.asarray(obj)
        raise TypeError("Input must be xarray.DataArray (with 'time') or pandas.Series with datetime index.")

    def _to_timedelta64(x):
        if isinstance(x, np.timedelta64):
            return x
        if hasattr(x, "to_timedelta64"):          # pandas.Timedelta keeps full precision
            return x.to_timedelta64()
        if isinstance(x, datetime.timedelta):
            return np.timedelta64(x)
        if isinstance(x, str):
            x = x.strip().upper()
            if x.endswith("D"):
                return np.timedelta64(int(x[:-1]), "D")
            raise ValueError("Only day strings like '15D' are supported.")
        # numpy integers are not `int` subclasses, so list them explicitly
        if isinstance(x, (int, float, np.integer, np.floating)):
            return np.timedelta64(int(x), "D")
        raise ValueError("shade_halfwidth must be np.timedelta64, '15D' string, or integer days.")

    t_v4r5, y_v4r5 = _get_time_and_vals(rms_v4r5)
    t_ctrl, y_ctrl = _get_time_and_vals(rms_ctrl)
    t_diff, y_diff = _get_time_and_vals(rms_diff)

    # Only shade GRACE times inside the span covered by the plotted series.
    t_min = np.min([t_v4r5.min(), t_ctrl.min(), t_diff.min()])
    t_max = np.max([t_v4r5.max(), t_ctrl.max(), t_diff.max()])

    grace_times = np.asarray(grace_times)
    halfw = _to_timedelta64(shade_halfwidth)

    if colors is None:
        colors = {"v4r5": "#1f77b4", "ctrl": "#ff7f0e", "diff": "#2ca02c"}

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    else:
        fig = ax.figure

    # Shading: one span per GRACE month, labelled once for the legend.
    first = True
    for t in grace_times:
        if (t < t_min) or (t > t_max):
            continue
        ax.axvspan(t - halfw, t + halfw, color="k", alpha=alpha, lw=0, zorder=0,
                   label=shade_label if first else None)
        first = False

    # Lines
    ax.plot(t_v4r5, y_v4r5, label="v4r5", color=colors["v4r5"], lw=1.2)
    ax.plot(t_ctrl, y_ctrl, label="ctrl", color=colors["ctrl"], lw=1.2)
    ax.plot(t_diff, y_diff, label="v4r5-ctrl", color=colors["diff"], lw=1.2)

    # Cosmetics
    ax.set_title(title, fontsize=13)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    if ylim_bottom is not None:
        ax.set_ylim(bottom=ylim_bottom)
    ax.grid(True, ls="--", lw=0.4, alpha=0.5)
    ax.legend(frameon=True)

    # Lay out BEFORE saving so the file matches what is displayed.
    if created_fig:
        fig.tight_layout()

    if savepath is not None:
        os.makedirs(os.path.dirname(savepath) or ".", exist_ok=True)
        skw = {"dpi": dpi, "bbox_inches": "tight"}
        if isinstance(save_kwargs, dict):
            skw.update(save_kwargs)
        fig.savefig(savepath, **skw)

    if created_fig:
        plt.show()

    return fig, ax



In [None]:
# Detrended RMS: remove a per-pixel linear trend from each field, then compute
# the band-limited spatial RMS time series for each latitude band below.
# (grace_dt is not used by the RMS loop in this cell.)
grace_dt = detrend_per_pixel(aligned_GRACE_pp_rg)          # (time, lat, lon)
v4r5_dt  = detrend_per_pixel(ECCO_pb_v4r5_pp_rg)
ctrl_dt  = detrend_per_pixel(ECCO_pb_ctrl_pp_rg)
diff_dt = v4r5_dt - ctrl_dt

# Latitude bands in degrees; entry i pairs with title_list[i] and fn_list[i].
lat_mins = [-90, 0, -90, -25, 25, -90, 65]
lat_maxs = [90, 90, 0, 25, 90, -25, 90]

rms_v4r5_dt = []
rms_ctrl_dt = []
rms_diff_dt = []

for i in range(len(lat_mins)):
    rms_v4r5 = band_spatial_rms_anom(v4r5_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_v4r5_dt.append(rms_v4r5)
    rms_ctrl = band_spatial_rms_anom(ctrl_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_ctrl_dt.append(rms_ctrl)
    rms_diff = band_spatial_rms_anom(diff_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_diff_dt.append(rms_diff)

# Titles must describe the bands actually computed above: band 4 spans 25N–90N
# (lat_maxs[4] == 90), mirroring band 5 (25S–90S).
# NOTE(review): if a 25N–65N mid-latitude band was intended, set lat_maxs[4] = 65.
title_list = ["Global", "Northern Hemisphere", "Southern Hemisphere",
              "Tropical Band (-25 to 25)", "25N to 90N", "25S to 90S",
              "Arctic Circle (65N to 90N)"]

fn_list = ["glob", "nh", "sh", "tropical", "midlat_nh", "midlat_sh", "arctic"]

assert len(title_list) == len(fn_list) == len(lat_mins) == len(lat_maxs)

In [None]:
# One figure per latitude band: detrended RMS of v4r5, ctrl and their difference.
for i in range(len(rms_v4r5_dt)):
    temp, temp2 = title_list[i], fn_list[i]   # band title, filename tag
    plot_rms_with_grace_shading(
        rms_v4r5_dt[i], rms_ctrl_dt[i], rms_diff_dt[i],
        grace_times=aligned_GRACE_pp_rg.time.values,
        title=f"{temp} RMS of ECCO v4r5, ctrl, and (v4r5−ctrl), No Trend",
        savepath=f"./RMS_plots/rms_timeseries_{temp2}_dt.png",
    )




In [None]:
# "With Trend" variant: same band loop on the un-detrended fields.
# NOTE(review): the `*_dt` names are reused here for fields that still contain
# their trends, overwriting the detrended arrays from the previous cell.
grace_dt = aligned_GRACE_pp_rg       # (time, lat, lon)
v4r5_dt  = ECCO_pb_v4r5_pp_rg
ctrl_dt  = ECCO_pb_ctrl_pp_rg
diff_dt = v4r5_dt - ctrl_dt

rms_v4r5_t, rms_ctrl_t, rms_diff_t = [], [], []

for i in range(len(lat_mins)):
    rms_v4r5 = band_spatial_rms_anom(v4r5_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_ctrl = band_spatial_rms_anom(ctrl_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_diff = band_spatial_rms_anom(diff_dt, lat_min=lat_mins[i], lat_max=lat_maxs[i])
    rms_v4r5_t.append(rms_v4r5)
    rms_ctrl_t.append(rms_ctrl)
    rms_diff_t.append(rms_diff)

In [None]:
# One figure per latitude band: RMS (trends retained) of v4r5, ctrl and their difference.
for i in range(len(rms_v4r5_t)):
    temp, temp2 = title_list[i], fn_list[i]   # band title, filename tag
    plot_rms_with_grace_shading(
        rms_v4r5_t[i], rms_ctrl_t[i], rms_diff_t[i],
        grace_times=aligned_GRACE_pp_rg.time.values,
        title=f"{temp} RMS of ECCO v4r5, ctrl, and (v4r5−ctrl), With Trend",
        savepath=f"./RMS_plots/rms_timeseries_{temp2}.png",
    )


In [None]:
import matplotlib.patches as mpatches  # only needed for the proxy method below

# Global band (index 0, -90..90) RMS with per-pixel trends removed; GRACE months shaded.
# Uses the detrended global series explicitly. The loop leftovers `rms_v4r5`,
# `rms_ctrl`, `rms_diff` hold whichever band/variant ran last (after a top-to-bottom
# run: the Arctic band WITH trend), which contradicts this title.
fig, ax = plot_rms_with_grace_shading(
    rms_v4r5_dt[0],
    rms_ctrl_dt[0],
    rms_diff_dt[0],
    grace_times=aligned_GRACE_pp_rg.time.values,
    title="RMS of ECCO v4r5, ctrl, and (v4r5−ctrl), Trends Removed (Global)",
    # savepath="GSTM_Fig3.png",   # uncomment to save the figure
)


### Comparison of ECCO Adjustments (std) with and without GRACE periods

In [None]:
# Exploratory version (superseded by the "Accepted Method" cell further down):
# ECCO is split into GRACE-available ("on") and GRACE-missing ("off") months FIRST,
# and each subset is then detrended / has its seasonal cycle removed separately.
# NOTE(review): detrend_per_pixel fits against sample index, so on a subset with
# time gaps (the "off" months in particular) the gaps are ignored by the trend fit.

# 1) GRACE sample times (presumably datetime64[ns] — confirm)
grace_times = aligned_GRACE_pp_rg.time.values  # datetime64[ns] array

# 2) Boolean masks on ECCO time axes: True where the ECCO time stamp exactly
#    matches a GRACE time (np.isin is an exact-equality test, so time stamps
#    must be aligned).
mask_v4r5_on  = xr.DataArray(np.isin(ECCO_pb_v4r5_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_v4r5_pp_rg.time})
mask_ctrl_on  = xr.DataArray(np.isin(ECCO_pb_ctrl_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_ctrl_pp_rg.time})

# 3) Subset ECCO to GRACE dates ("on") and all remaining dates ("off")
# v4r5
ecco_v4r5_on_grace   = ECCO_pb_v4r5_pp_rg.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace  = ECCO_pb_v4r5_pp_rg.sel(time=~mask_v4r5_on)

# ctrl
ecco_ctrl_on_grace   = ECCO_pb_ctrl_pp_rg.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace  = ECCO_pb_ctrl_pp_rg.sel(time=~mask_ctrl_on)

# ECCO adjustment = ctrl − v4r5 (sign does not affect the std computed below)
ECCO_adjt_raw_on = ecco_ctrl_on_grace - ecco_v4r5_on_grace
ECCO_adjt_raw_off = ecco_ctrl_off_grace - ecco_v4r5_off_grace

# detrend each subset separately (see NOTE at top about gaps)
ecco_v4r5_on_grace_dt = detrend_per_pixel(ecco_v4r5_on_grace)
ecco_v4r5_off_grace_dt = detrend_per_pixel(ecco_v4r5_off_grace)

ecco_ctrl_on_grace_dt = detrend_per_pixel(ecco_ctrl_on_grace)
ecco_ctrl_off_grace_dt = detrend_per_pixel(ecco_ctrl_off_grace)

ECCO_adjt_dt_on = ecco_ctrl_on_grace_dt - ecco_v4r5_on_grace_dt
ECCO_adjt_dt_off = ecco_ctrl_off_grace_dt- ecco_v4r5_off_grace_dt

# MSC removed (residuals): remove_seasonal_cycle is defined elsewhere in the
# notebook — presumably removes the mean seasonal cycle; confirm.

ecco_v4r5_on_grace_sr = remove_seasonal_cycle(ecco_v4r5_on_grace_dt)
ecco_v4r5_off_grace_sr = remove_seasonal_cycle(ecco_v4r5_off_grace_dt)

ecco_ctrl_on_grace_sr = remove_seasonal_cycle(ecco_ctrl_on_grace_dt)
ecco_ctrl_off_grace_sr = remove_seasonal_cycle(ecco_ctrl_off_grace_dt)

ECCO_adjt_sr_on = ecco_ctrl_on_grace_sr - ecco_v4r5_on_grace_sr
ECCO_adjt_sr_off = ecco_ctrl_off_grace_sr- ecco_v4r5_off_grace_sr

# (on, off) pairs for raw, detrended and residual adjustments
da = [ECCO_adjt_raw_on, ECCO_adjt_raw_off,
      ECCO_adjt_dt_on, ECCO_adjt_dt_off,
      ECCO_adjt_sr_on, ECCO_adjt_sr_off]

# stds ends up with 9 maps in row-major panel order:
# [raw_on, raw_off, raw_on−off, dt_on, dt_off, dt_on−off, sr_on, sr_off, sr_on−off]
stds = []
temp = []  # unused

for i in range(0, len(da), 2):
    std_on  = da[i].std(dim = 'time')
    std_off = da[i+1].std(dim = 'time')
    dif = std_on - std_off
    #temp = [std_on, std_off, dif]
    stds.extend([std_on, std_off, dif])
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap, BoundaryNorm

title_list = [
    "ECCO Adjustments Std. (Raw) - GRACE On", "ECCO Adjustments Std. (Raw) - GRACE Off", "On-Off Difference (Raw)",
    "ECCO Adjustments Std. (Detrended) - GRACE On", "ECCO Adjustments Std. (Detrended) - GRACE Off", "On-Off Difference (Detrended)",
    "ECCO Adjustments Std. (Residual) - GRACE On", "ECCO Adjustments Std. (Residual) - GRACE Off", "On-Off Difference (Residual)"
]
labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)']

data_list = stds  # length-9 list of DataArrays with latitude/longitude coords, row-major panel order

# Columns 1–2 (std maps): discrete jet colormap on 0.5-wide bins over [0, 6].
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
n_bins = 20
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])
boundaries = np.arange(0, 6.0 + 0.5, 0.5)
norm_disc = BoundaryNorm(boundaries, ncolors=len(cmap_disc.colors))

# Column 3 (on − off difference): discrete diverging colormap on 0.5-wide bins over [-2, 2].
n_bins2 = 20
boundaries2 = np.arange(-2, 2 + 0.5, 0.5)
cmap_disc2 = plt.get_cmap("coolwarm", n_bins2)
norm_disc2 = BoundaryNorm(boundaries2, ncolors=n_bins2)

nrows, ncols = 3, 3
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 8), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

axes = axes.reshape(nrows, ncols)
pc_leftmid = None   # last QuadMesh from columns 1–2 (drives the shared colorbar)
pc_right   = None   # last QuadMesh from column 3

for i_row in range(nrows):
    for j_col in range(ncols):
        idx = i_row * ncols + j_col
        ax  = axes[i_row, j_col]
        da  = data_list[idx]

        # columns 1–2 share one color scale, column 3 another
        if j_col < 2:
            pc = ax.pcolormesh(
                da.longitude, da.latitude, da,
                transform=ccrs.PlateCarree(),
                cmap=cmap_disc, norm=norm_disc
            )
            pc_leftmid = pc
        else:
            pc = ax.pcolormesh(
                da.longitude, da.latitude, da,
                transform=ccrs.PlateCarree(),
                cmap=cmap_disc2, norm=norm_disc2
            )
            pc_right = pc

        ax.set_global()
        ax.add_feature(cfeature.NaturalEarthFeature(
            'physical', 'land', '110m',
            linewidth=0.5, edgecolor='black', facecolor='darkgray'
        ))

        # Latitude labels only on the left column, longitude labels only on the
        # bottom row. Cartopy >= 0.18 uses *_labels; the old xlabels_*/ylabels_*
        # names were deprecated and later removed (their warnings are hidden by
        # the global warnings filter), so setting them no longer has any effect.
        gl = ax.gridlines(color='gray', linestyle='--', linewidth=0.2, draw_labels=True)
        gl.left_labels   = (j_col == 0)
        gl.right_labels  = False
        gl.bottom_labels = (i_row == nrows - 1)
        gl.top_labels    = False
        gl.xlabel_style = {'size': 7}
        gl.ylabel_style = {'size': 7}

        ax.set_title(title_list[idx], fontsize=11)
        ax.text(-0.07, 1., labels[idx], transform=ax.transAxes,
                fontsize=8, fontweight='bold', va='top', ha='left',
                bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Leave space for two bottom colorbars
plt.tight_layout(rect=[0, 0.12, 1, 1])

# Shared horizontal colorbar spanning columns 1–2
pos_left  = axes[-1, 0].get_position()
pos_mid   = axes[-1, 1].get_position()
left_x    = pos_left.x0
right_x   = pos_mid.x1
width     = right_x - left_x
cax12 = fig.add_axes([left_x, pos_left.y0 - 0.07, width, 0.022])
cb12  = fig.colorbar(pc_leftmid, cax=cax12, orientation='horizontal', ticks=boundaries)
cb12.set_label("Std. (cm)", fontsize=10)
cb12.ax.tick_params(labelsize=8)

# Shared horizontal colorbar for column 3
pos_right = axes[-1, 2].get_position()
cax3 = fig.add_axes([pos_right.x0, pos_right.y0 - 0.07, pos_right.width, 0.022])
cb3  = fig.colorbar(pc_right, cax=cax3, orientation='horizontal', ticks=boundaries2)
cb3.set_label("Difference (cm)", fontsize=10)
cb3.ax.tick_params(labelsize=8)
fig.savefig("ECCO_adjustments_std_GRACE_on_off.png", dpi=500, bbox_inches='tight')
plt.show()


### In Liz's previous plots, the full GRACE/ECCO common period was used as "GRACE on". This may be why she sees positive values in the Arctic in subplot (f).

In [None]:
# Spot check at one grid cell (index 35, 35): v4r5 values on GRACE months vs. the
# remaining months. Markers only: the "off" subset skips every GRACE month, so it
# is non-contiguous, and connecting lines would draw spurious straight segments
# across the skipped periods.
ecco_v4r5_on_grace.isel(longitude=35, latitude=35).plot(marker=".", linestyle="none", label="GRACE on")
ecco_v4r5_off_grace.isel(longitude=35, latitude=35).plot(marker=".", linestyle="none", label="GRACE off")
plt.legend();

In [None]:
# ECCO fields restricted to the GRACE/ECCO common window (2002-04-01 .. 2019-12-01).
ECCO_ctrl_temp, ECCO_v4r5_temp = (
    field.sel(time=slice("2002-04-01", "2019-12-01"))
    for field in (ECCO_pb_ctrl_pp_rg, ECCO_pb_v4r5_pp_rg)
)


In [None]:
# Reproduction of Liz's definition: "on" is the whole 2002-04..2019-12 window,
# while "off" is every ECCO month whose time stamp is NOT a GRACE time.
# NOTE(review): GRACE gap months inside 2002-04..2019-12 therefore fall in BOTH
# subsets. Adjustments are detrended / deseasonalized after subsetting, and
# detrend_per_pixel fits against sample index, so gaps in "off" are ignored by the fit.

# 1) GRACE sample times (presumably datetime64[ns] — confirm)
grace_times = aligned_GRACE_pp_rg.time.values  # datetime64[ns] array

# 2) Boolean masks on ECCO time axes (exact time-stamp matches via np.isin)
mask_v4r5_on  = xr.DataArray(np.isin(ECCO_pb_v4r5_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_v4r5_pp_rg.time})
mask_ctrl_on  = xr.DataArray(np.isin(ECCO_pb_ctrl_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_ctrl_pp_rg.time})

# 3) "on" = time window (not the GRACE mask); "off" = outside GRACE dates
# v4r5
ecco_ctrl_on_grace = ECCO_pb_ctrl_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_v4r5_on_grace = ECCO_pb_v4r5_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))
#ecco_v4r5_on_grace   = ECCO_pb_v4r5_pp_rg.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace  = ECCO_pb_v4r5_pp_rg.sel(time=~mask_v4r5_on)

# ctrl
#ecco_ctrl_on_grace   = ECCO_pb_ctrl_pp_rg.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace  = ECCO_pb_ctrl_pp_rg.sel(time=~mask_ctrl_on)

# ECCO adjustment = ctrl − v4r5 (sign does not affect the std computed below)
ECCO_adjt_raw_on = ecco_ctrl_on_grace - ecco_v4r5_on_grace
ECCO_adjt_raw_off = ecco_ctrl_off_grace - ecco_v4r5_off_grace

# detrend the individual subsets (these *_dt fields are not used in `da` below)
ecco_v4r5_on_grace_dt = detrend_per_pixel(ecco_v4r5_on_grace)
ecco_v4r5_off_grace_dt = detrend_per_pixel(ecco_v4r5_off_grace)

ecco_ctrl_on_grace_dt = detrend_per_pixel(ecco_ctrl_on_grace)
ecco_ctrl_off_grace_dt = detrend_per_pixel(ecco_ctrl_off_grace)

#ECCO_adjt_dt_on = ecco_ctrl_on_grace_dt - ecco_v4r5_on_grace_dt
#ECCO_adjt_dt_off = ecco_ctrl_off_grace_dt- ecco_v4r5_off_grace_dt

# Here the adjustment difference itself is detrended
ECCO_adjt_dt_on = detrend_per_pixel(ECCO_adjt_raw_on)
ECCO_adjt_dt_off = detrend_per_pixel(ECCO_adjt_raw_off)

# MSC removed (residuals): remove_seasonal_cycle is defined elsewhere in the
# notebook — presumably removes the mean seasonal cycle; confirm.

ecco_v4r5_on_grace_sr = remove_seasonal_cycle(ecco_v4r5_on_grace_dt)
ecco_v4r5_off_grace_sr = remove_seasonal_cycle(ecco_v4r5_off_grace_dt)

ecco_ctrl_on_grace_sr = remove_seasonal_cycle(ecco_ctrl_on_grace_dt)
ecco_ctrl_off_grace_sr = remove_seasonal_cycle(ecco_ctrl_off_grace_dt)

ECCO_adjt_sr_on = remove_seasonal_cycle(ECCO_adjt_dt_on)
ECCO_adjt_sr_off = remove_seasonal_cycle(ECCO_adjt_dt_off)

# (on, off) pairs for raw, detrended and residual adjustments
da = [ECCO_adjt_raw_on, ECCO_adjt_raw_off,
      ECCO_adjt_dt_on, ECCO_adjt_dt_off,
      ECCO_adjt_sr_on, ECCO_adjt_sr_off]

# stds: 9 maps in row-major panel order (on, off, on−off for raw / dt / sr)
stds = []
temp = []  # unused

for i in range(0, len(da), 2):
    std_on  = da[i].std(dim = 'time')
    std_off = da[i+1].std(dim = 'time')
    dif = std_on - std_off
    #temp = [std_on, std_off, dif]
    stds.extend([std_on, std_off, dif])
    

In [None]:
# Variant: detrend / deseasonalize the FULL ECCO series first, then subset.
# "on" is still the whole 2002-04..2019-12 window (Liz's definition) and "off" is
# every month whose time stamp is not a GRACE time.
# NOTE(review): GRACE gap months inside the window therefore fall in BOTH subsets.
ecco_ctrl_on_grace = ECCO_pb_ctrl_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_v4r5_on_grace = ECCO_pb_v4r5_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))

# Full-series processing (continuous time axis)
ECCO_pb_ctrl_pp_rg_dt = detrend_per_pixel(ECCO_pb_ctrl_pp_rg)
ECCO_pb_v4r5_pp_rg_dt = detrend_per_pixel(ECCO_pb_v4r5_pp_rg)

# remove_seasonal_cycle is defined elsewhere (presumably removes the mean seasonal cycle)
ECCO_pb_ctrl_pp_rg_sr = remove_seasonal_cycle(ECCO_pb_ctrl_pp_rg_dt)
ECCO_pb_v4r5_pp_rg_sr = remove_seasonal_cycle(ECCO_pb_v4r5_pp_rg_dt)

# 1) GRACE sample times (presumably datetime64[ns] — confirm)
grace_times = aligned_GRACE_pp_rg.time.values  # datetime64[ns] array

# 2) Boolean masks on ECCO time axes (exact time-stamp matches via np.isin)
mask_v4r5_on  = xr.DataArray(np.isin(ECCO_pb_v4r5_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_v4r5_pp_rg.time})
mask_ctrl_on  = xr.DataArray(np.isin(ECCO_pb_ctrl_pp_rg.time.values, grace_times),
                             dims=["time"], coords={"time": ECCO_pb_ctrl_pp_rg.time})

# 3) "off" subsets from the mask ("on" subsets were taken from the window above)
# v4r5
#ecco_ctrl_on_grace = ECCO_pb_ctrl_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))
#ecco_v4r5_on_grace = ECCO_pb_v4r5_pp_rg.sel(time=slice("2002-04-01", "2019-12-01"))
#ecco_v4r5_on_grace   = ECCO_pb_v4r5_pp_rg.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace  = ECCO_pb_v4r5_pp_rg.sel(time=~mask_v4r5_on)

# ctrl
#ecco_ctrl_on_grace   = ECCO_pb_ctrl_pp_rg.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace  = ECCO_pb_ctrl_pp_rg.sel(time=~mask_ctrl_on)

# ECCO adjustment = ctrl − v4r5 (sign does not affect the std computed below)
ECCO_adjt_raw_on = ecco_ctrl_on_grace - ecco_v4r5_on_grace
ECCO_adjt_raw_off = ecco_ctrl_off_grace - ecco_v4r5_off_grace

# detrended: subset the full-series detrended fields
ecco_v4r5_on_grace_dt = ECCO_pb_v4r5_pp_rg_dt.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_v4r5_off_grace_dt = ECCO_pb_v4r5_pp_rg_dt.sel(time=~mask_v4r5_on)

ecco_ctrl_on_grace_dt = ECCO_pb_ctrl_pp_rg_dt.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_ctrl_off_grace_dt = ECCO_pb_ctrl_pp_rg_dt.sel(time=~mask_ctrl_on)

#ecco_v4r5_on_grace_dt = detrend_per_pixel(ecco_v4r5_on_grace)
#ecco_v4r5_off_grace_dt = detrend_per_pixel(ecco_v4r5_off_grace)

#ecco_ctrl_on_grace_dt = detrend_per_pixel(ecco_ctrl_on_grace)
#ecco_ctrl_off_grace_dt = detrend_per_pixel(ecco_ctrl_off_grace)

ECCO_adjt_dt_on = ecco_ctrl_on_grace_dt - ecco_v4r5_on_grace_dt
ECCO_adjt_dt_off = ecco_ctrl_off_grace_dt- ecco_v4r5_off_grace_dt

#ECCO_adjt_dt_on = detrend_per_pixel(ECCO_adjt_raw_on)
#ECCO_adjt_dt_off = detrend_per_pixel(ECCO_adjt_raw_off)

# MSC removed (residuals): subset the full-series residual fields

ecco_v4r5_on_grace_sr = ECCO_pb_v4r5_pp_rg_sr.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_v4r5_off_grace_sr = ECCO_pb_v4r5_pp_rg_sr.sel(time=~mask_v4r5_on)

ecco_ctrl_on_grace_sr = ECCO_pb_ctrl_pp_rg_sr.sel(time=slice("2002-04-01", "2019-12-01"))
ecco_ctrl_off_grace_sr = ECCO_pb_ctrl_pp_rg_sr.sel(time=~mask_ctrl_on)


#ecco_v4r5_on_grace_sr = remove_seasonal_cycle(ecco_v4r5_on_grace_dt)
#ecco_v4r5_off_grace_sr = remove_seasonal_cycle(ecco_v4r5_off_grace_dt)

#ecco_ctrl_on_grace_sr = remove_seasonal_cycle(ecco_ctrl_on_grace_dt)
#ecco_ctrl_off_grace_sr = remove_seasonal_cycle(ecco_ctrl_off_grace_dt)

ECCO_adjt_sr_on = ecco_ctrl_on_grace_sr - ecco_v4r5_on_grace_sr
ECCO_adjt_sr_off = ecco_ctrl_off_grace_sr- ecco_v4r5_off_grace_sr

#ECCO_adjt_sr_on = remove_seasonal_cycle(ECCO_adjt_dt_on)
#ECCO_adjt_sr_off = remove_seasonal_cycle(ECCO_adjt_dt_off)

# (on, off) pairs for raw, detrended and residual adjustments
da = [ECCO_adjt_raw_on, ECCO_adjt_raw_off,
      ECCO_adjt_dt_on, ECCO_adjt_dt_off,
      ECCO_adjt_sr_on, ECCO_adjt_sr_off]

# stds: 9 maps in row-major panel order (on, off, on−off for raw / dt / sr)
stds = []
temp = []  # unused

for i in range(0, len(da), 2):
    std_on  = da[i].std(dim = 'time')
    std_off = da[i+1].std(dim = 'time')
    dif = std_on - std_off
    #temp = [std_on, std_off, dif]
    stds.extend([std_on, std_off, dif])
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap, BoundaryNorm

title_list = [
    "ECCO Adjustments Std. (Raw) - GRACE On", "ECCO Adjustments Std. (Raw) - GRACE Off", "On-Off Difference (Raw)",
    "ECCO Adjustments Std. (Detrended) - GRACE On", "ECCO Adjustments Std. (Detrended) - GRACE Off", "On-Off Difference (Detrended)",
    "ECCO Adjustments Std. (Residual) - GRACE On", "ECCO Adjustments Std. (Residual) - GRACE Off", "On-Off Difference (Residual)"
]
labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)']

data_list = stds  # length-9 list of DataArrays with latitude/longitude coords, row-major panel order

# Columns 1–2 (std maps): discrete jet colormap on 0.5-wide bins over [0, 5].
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
n_bins = 20
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])
boundaries = np.arange(0, 5.0 + 0.5, 0.5)
norm_disc = BoundaryNorm(boundaries, ncolors=len(cmap_disc.colors))

# Column 3 (on − off difference): discrete diverging colormap on 0.25-wide bins over [-1, 1].
n_bins2 = 20
boundaries2 = np.arange(-1, 1 + 0.25, 0.25)
cmap_disc2 = plt.get_cmap("coolwarm", n_bins2)
norm_disc2 = BoundaryNorm(boundaries2, ncolors=n_bins2)

nrows, ncols = 3, 3
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 8), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

axes = axes.reshape(nrows, ncols)
pc_leftmid = None   # last QuadMesh from columns 1–2 (drives the shared colorbar)
pc_right   = None   # last QuadMesh from column 3

for i_row in range(nrows):
    for j_col in range(ncols):
        idx = i_row * ncols + j_col
        ax  = axes[i_row, j_col]
        da  = data_list[idx]

        # columns 1–2 share one color scale, column 3 another
        if j_col < 2:
            pc = ax.pcolormesh(
                da.longitude, da.latitude, da,
                transform=ccrs.PlateCarree(),
                cmap=cmap_disc, norm=norm_disc
            )
            pc_leftmid = pc
        else:
            pc = ax.pcolormesh(
                da.longitude, da.latitude, da,
                transform=ccrs.PlateCarree(),
                cmap=cmap_disc2, norm=norm_disc2
            )
            pc_right = pc

        ax.set_global()
        ax.add_feature(cfeature.NaturalEarthFeature(
            'physical', 'land', '110m',
            linewidth=0.5, edgecolor='black', facecolor='darkgray'
        ))

        # Latitude labels only on the left column, longitude labels only on the
        # bottom row. Cartopy >= 0.18 uses *_labels; the old xlabels_*/ylabels_*
        # names were deprecated and later removed (their warnings are hidden by
        # the global warnings filter), so setting them no longer has any effect.
        gl = ax.gridlines(color='gray', linestyle='--', linewidth=0.2, draw_labels=True)
        gl.left_labels   = (j_col == 0)
        gl.right_labels  = False
        gl.bottom_labels = (i_row == nrows - 1)
        gl.top_labels    = False
        gl.xlabel_style = {'size': 7}
        gl.ylabel_style = {'size': 7}

        ax.set_title(title_list[idx], fontsize=11)
        ax.text(-0.07, 1., labels[idx], transform=ax.transAxes,
                fontsize=8, fontweight='bold', va='top', ha='left',
                bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Leave space for two bottom colorbars
plt.tight_layout(rect=[0, 0.12, 1, 1])

# Shared horizontal colorbar spanning columns 1–2
pos_left  = axes[-1, 0].get_position()
pos_mid   = axes[-1, 1].get_position()
left_x    = pos_left.x0
right_x   = pos_mid.x1
width     = right_x - left_x
cax12 = fig.add_axes([left_x, pos_left.y0 - 0.07, width, 0.022])
cb12  = fig.colorbar(pc_leftmid, cax=cax12, orientation='horizontal', ticks=boundaries)
cb12.set_label("Std. (cm)", fontsize=10)
cb12.ax.tick_params(labelsize=8)

# Shared horizontal colorbar for column 3
pos_right = axes[-1, 2].get_position()
cax3 = fig.add_axes([pos_right.x0, pos_right.y0 - 0.07, pos_right.width, 0.022])
cb3  = fig.colorbar(pc_right, cax=cax3, orientation='horizontal', ticks=boundaries2)
cb3.set_label("Difference (cm)", fontsize=10)
cb3.ax.tick_params(labelsize=8)
fig.savefig("ECCO_adjustments_std_GRACE_on_off_updated.png", dpi=500, bbox_inches='tight')
plt.show()


### Accepted Method
1. First, detrend (and remove the seasonal cycle from) the complete ECCO time series.
2. Then split the processed ECCO series into GRACE and non-GRACE dates.
3. Finally, compute the statistics on each subset.

In [None]:
# Accepted method: detrend and remove the seasonal cycle on the FULL, continuous
# ECCO series first, then split every processed field into GRACE-available
# ("on") and GRACE-missing ("off") months. Processing the full series avoids
# fitting trends / seasonal cycles on gappy subsets (detrend_per_pixel fits
# against sample index, so gaps inside a subset would be ignored).

# 1) Full-series processing
ECCO_pb_ctrl_pp_rg_dt = detrend_per_pixel(ECCO_pb_ctrl_pp_rg)
ECCO_pb_v4r5_pp_rg_dt = detrend_per_pixel(ECCO_pb_v4r5_pp_rg)

# remove_seasonal_cycle is defined elsewhere (presumably removes the mean seasonal cycle)
ECCO_pb_ctrl_pp_rg_sr = remove_seasonal_cycle(ECCO_pb_ctrl_pp_rg_dt)
ECCO_pb_v4r5_pp_rg_sr = remove_seasonal_cycle(ECCO_pb_v4r5_pp_rg_dt)

# 2) "GRACE on" masks on each ECCO time axis. np.isin needs exact time-stamp
#    matches, so an all-False mask means GRACE and ECCO times are not aligned.
grace_times = aligned_GRACE_pp_rg.time.values
mask_v4r5_on = xr.DataArray(np.isin(ECCO_pb_v4r5_pp_rg.time.values, grace_times),
                            dims=["time"], coords={"time": ECCO_pb_v4r5_pp_rg.time})
mask_ctrl_on = xr.DataArray(np.isin(ECCO_pb_ctrl_pp_rg.time.values, grace_times),
                            dims=["time"], coords={"time": ECCO_pb_ctrl_pp_rg.time})
assert bool(mask_v4r5_on.any()) and bool(mask_ctrl_on.any()), \
    "no ECCO time stamps match GRACE times — check time alignment"

# 3) Split raw, detrended (_dt) and residual (_sr) fields into on/off subsets.
ecco_v4r5_on_grace     = ECCO_pb_v4r5_pp_rg.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace    = ECCO_pb_v4r5_pp_rg.sel(time=~mask_v4r5_on)
ecco_ctrl_on_grace     = ECCO_pb_ctrl_pp_rg.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace    = ECCO_pb_ctrl_pp_rg.sel(time=~mask_ctrl_on)

ecco_v4r5_on_grace_dt  = ECCO_pb_v4r5_pp_rg_dt.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace_dt = ECCO_pb_v4r5_pp_rg_dt.sel(time=~mask_v4r5_on)
ecco_ctrl_on_grace_dt  = ECCO_pb_ctrl_pp_rg_dt.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace_dt = ECCO_pb_ctrl_pp_rg_dt.sel(time=~mask_ctrl_on)

ecco_v4r5_on_grace_sr  = ECCO_pb_v4r5_pp_rg_sr.sel(time=mask_v4r5_on)
ecco_v4r5_off_grace_sr = ECCO_pb_v4r5_pp_rg_sr.sel(time=~mask_v4r5_on)
ecco_ctrl_on_grace_sr  = ECCO_pb_ctrl_pp_rg_sr.sel(time=mask_ctrl_on)
ecco_ctrl_off_grace_sr = ECCO_pb_ctrl_pp_rg_sr.sel(time=~mask_ctrl_on)

# 4) ECCO adjustments = ctrl − v4r5 (sign does not affect the std below).
ECCO_adjt_raw_on  = ecco_ctrl_on_grace - ecco_v4r5_on_grace
ECCO_adjt_raw_off = ecco_ctrl_off_grace - ecco_v4r5_off_grace
ECCO_adjt_dt_on   = ecco_ctrl_on_grace_dt - ecco_v4r5_on_grace_dt
ECCO_adjt_dt_off  = ecco_ctrl_off_grace_dt - ecco_v4r5_off_grace_dt
ECCO_adjt_sr_on   = ecco_ctrl_on_grace_sr - ecco_v4r5_on_grace_sr
ECCO_adjt_sr_off  = ecco_ctrl_off_grace_sr - ecco_v4r5_off_grace_sr

# 5) Per-pixel temporal std for each (on, off) pair plus the on − off difference.
#    `stds` order (row-major 3×3 panels):
#    [raw_on, raw_off, raw_diff, dt_on, dt_off, dt_diff, sr_on, sr_off, sr_diff]
da = [ECCO_adjt_raw_on, ECCO_adjt_raw_off,
      ECCO_adjt_dt_on, ECCO_adjt_dt_off,
      ECCO_adjt_sr_on, ECCO_adjt_sr_off]

stds = []
for field_on, field_off in zip(da[0::2], da[1::2]):
    std_on = field_on.std(dim='time')
    std_off = field_off.std(dim='time')
    stds.extend([std_on, std_off, std_on - std_off])
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap, BoundaryNorm

# 3x3 maps of the temporal std of the ECCO adjustments (ctrl - v4r5).
# Rows: raw, detrended, residual; columns: GRACE on, GRACE off, on - off.
# Panel idx (row-major) shows stds[idx].

title_list = [
    "ECCO Adjustments Std. (Raw) - GRACE On", "ECCO Adjustments Std. (Raw) - GRACE Off", "On-Off Difference (Raw)",
    "ECCO Adjustments Std. (Detrended) - GRACE On", "ECCO Adjustments Std. (Detrended) - GRACE Off", "On-Off Difference (Detrended)",
    "ECCO Adjustments Std. (Residual) - GRACE On", "ECCO Adjustments Std. (Residual) - GRACE Off", "On-Off Difference (Residual)"
]
labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)']

data_list = stds  # length 9 list of DataArrays with coords latitude/longitude

# plt.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was
# removed in Matplotlib 3.9.

# Columns 1-2 (std): discrete jet, 0-5 cm in 0.5 cm bins
n_bins = 20
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])
boundaries = np.arange(0, 5.0 + 0.5, 0.5)
norm_disc = BoundaryNorm(boundaries, ncolors=len(cmap_disc.colors))

# Column 3 (on - off difference): discrete coolwarm, -1..1 cm in 0.25 cm bins
n_bins2 = 20
boundaries2 = np.arange(-1, 1 + 0.25, 0.25)
cmap_disc2 = plt.get_cmap("coolwarm", n_bins2)
norm_disc2 = BoundaryNorm(boundaries2, ncolors=n_bins2)

nrows, ncols = 3, 3
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 8), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

axes = axes.reshape(nrows, ncols)
pc_leftmid = None   # last QuadMesh from columns 1-2 (drives the shared std colorbar)
pc_right   = None   # last QuadMesh from column 3 (drives the difference colorbar)

for i_row in range(nrows):
    for j_col in range(ncols):
        idx = i_row * ncols + j_col
        ax = axes[i_row, j_col]
        # Named `field` (not `da`) so the list `da` built in the previous cell is not overwritten
        field = data_list[idx]

        # choose colormap group by column
        is_std_column = j_col < 2
        pc = ax.pcolormesh(
            field.longitude, field.latitude, field,
            transform=ccrs.PlateCarree(),
            cmap=cmap_disc if is_std_column else cmap_disc2,
            norm=norm_disc if is_std_column else norm_disc2,
        )
        if is_std_column:
            pc_leftmid = pc
        else:
            pc_right = pc

        ax.set_global()
        ax.add_feature(cfeature.NaturalEarthFeature(
            'physical', 'land', '110m',
            linewidth=0.5, edgecolor='black', facecolor='darkgray'
        ))

        # Lat labels on the left column only, lon labels on the bottom row only.
        # Uses the Cartopy >= 0.18 names; the old xlabels_*/ylabels_* names are
        # deprecated, and where they no longer exist assigning them has no effect.
        gl = ax.gridlines(color='gray', linestyle='--', linewidth=0.2, draw_labels=True)
        gl.left_labels   = (j_col == 0)
        gl.right_labels  = False
        gl.bottom_labels = (i_row == nrows - 1)
        gl.top_labels    = False
        gl.xlabel_style = {'size': 7}
        gl.ylabel_style = {'size': 7}

        ax.set_title(title_list[idx], fontsize=11)
        ax.text(-0.07, 1., labels[idx], transform=ax.transAxes,
                fontsize=8, fontweight='bold', va='top', ha='left',
                bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Leave space for two bottom colorbars
plt.tight_layout(rect=[0, 0.12, 1, 1])

# Shared horizontal colorbar for columns 1-2, spanning both columns
pos_left  = axes[-1, 0].get_position()
pos_mid   = axes[-1, 1].get_position()
left_x    = pos_left.x0
right_x   = pos_mid.x1
width     = right_x - left_x
cax12 = fig.add_axes([left_x, pos_left.y0 - 0.07, width, 0.022])
cb12  = fig.colorbar(pc_leftmid, cax=cax12, orientation='horizontal', ticks=boundaries)
cb12.set_label("Std. (cm)", fontsize=10)
cb12.ax.tick_params(labelsize=8)

# Shared horizontal colorbar for column 3
pos_right = axes[-1, 2].get_position()
cax3 = fig.add_axes([pos_right.x0, pos_right.y0 - 0.07, pos_right.width, 0.022])
cb3  = fig.colorbar(pc_right, cax=cax3, orientation='horizontal', ticks=boundaries2)
cb3.set_label("Difference (cm)", fontsize=10)
cb3.ax.tick_params(labelsize=8)
plt.savefig("ECCO_adjustments_std_true_GRACE_on_off_updated.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Single map: on-minus-off std difference of the detrended ECCO adjustments.
# stds[5] is the detrended "On-Off Difference" (stds order: raw on/off/diff,
# detrended on/off/diff, residual on/off/diff).

# Discrete diverging colormap: 20 colours sampled evenly from "coolwarm",
# binned over -1..1 cm in 0.25 cm steps.
# (plt.get_cmap: matplotlib.cm.get_cmap / plt.cm.get_cmap was removed in Matplotlib 3.9.)
n_bins = 20
coolwarm = plt.get_cmap("coolwarm", n_bins)
coolwarm_colors = [coolwarm(i) for i in range(n_bins)]

cmap_disc = ListedColormap(coolwarm_colors)
boundaries = np.arange(-1, 1 + 0.25, 0.25)
norm_disc = BoundaryNorm(boundaries, ncolors=len(coolwarm_colors))

data = stds[5]

fig, ax = plt.subplots(figsize=(11, 6), dpi=300,
                       subplot_kw={'projection': ccrs.PlateCarree()})   # GeoAxes

pc = ax.pcolormesh(data.longitude,
                   data.latitude,
                   data,
                   transform=ccrs.PlateCarree(),
                   cmap=cmap_disc,
                   norm=norm_disc)
ax.set_global()

# Gridlines with lon labels at the bottom and lat labels on the left.
# draw_labels=True is required: without it Cartopy draws no labels at all.
gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
gl.top_labels = False
gl.right_labels = False

# Add land feature
ax.add_feature(cfeature.NaturalEarthFeature(
    'physical', 'land', '110m',
    linewidth=0.5, edgecolor='black', facecolor='darkgray'
))

ax.set_title("Detrended ECCO Adjustments (GRACE On Std. minus GRACE Off Std)", fontsize=12)

# Colorbar
cbar_ax = fig.add_axes([0.12, 0.05, 0.78, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries)
cbar.ax.set_xticklabels([f'{b:.2f}' for b in boundaries])
cbar.set_label("(cm)")

plt.savefig("ECCO_adjustments_std_true_GRACE_on_off_detrended.png", dpi=500, bbox_inches='tight')

plt.show()

### GRD-Related PVE (Percentage of Variance Explained) Calculations

In [None]:
# Open the GRACE GRD dataset. (The "alinged" spelling is kept because later cells use this name.)
GRD_alinged = xr.open_dataset(
    '/glade/work/netige/Data/GRACE/GRACE_GRD_aligned.nc'
)

In [None]:
# Convert the GRACE GRD time axis from decimal years to month-start datetimes.
#
# Month number = round(fractional_part * 12 + 1), using numpy's
# round-half-to-even rule as before. Year and month are combined as an
# integer month count, so a value that rounds to "month 13" rolls over to
# January of the next year. No date strings are built from floats:
# astype(str) on the rounded float month produced e.g. '2002-4.0-01',
# which zfill(2) does not pad.
# Re-running the cell is a no-op once the axis is already datetime64.

if not np.issubdtype(GRD_alinged.time.dtype, np.datetime64):
    time_vals = np.asarray(GRD_alinged.time, dtype=float)

    base_year = np.floor(time_vals).astype(int)   # integer year part
    fractional_part = time_vals - base_year       # fraction of the year, in [0, 1)

    mon = np.round(fractional_part * 12 + 1).astype(int)   # 1..13

    # Months since 1970-01; integers cast to datetime64[M] count from the epoch.
    months_since_epoch = (base_year - 1970) * 12 + (mon - 1)
    dates = pd.DatetimeIndex(
        months_since_epoch.astype('datetime64[M]').astype('datetime64[ns]')
    )

    GRD_alinged = GRD_alinged.assign_coords(time=dates)

In [None]:
# Regrid the GRACE GRD field from the ECCO LLC90 grid onto a global 0.5-degree
# latitude/longitude grid, one time step at a time.

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Source field and the ECCO grid it is defined on
data = GRD_alinged.GRD
grid = eccor5_grid

# Target grid: resolution and global bounds (degrees)
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Surface wet-cell mask: points where the top-level hFacC is 0 become NaN before regridding
surface_is_wet = grid.hFacC.isel(k=0) != 0

# Nearest-neighbour resample of each masked time slice. Element [4] of the
# returned tuple is the regridded 2-D field; the other four values are discarded.
global_data_list = [
    ecco.resample_to_latlon(
        grid.XC, grid.YC, data.isel(time=step).where(surface_is_wet),
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan,
    )[4]
    for step in range(data.sizes['time'])
]

# Stack into a (time, lat, lon) array
global_data_array = np.stack(global_data_list, axis=0)

# NOTE(review): these evenly spaced labels run from the grid *edges*
# (-90..90, -180..180), so they are offset from the true cell centres by up to
# half a cell (0.25 deg). Per the ecco_v4_py docs, resample_to_latlon also
# returns the centre coordinates; those are among the values discarded above.
# Presumably left this way so this product shares identical coordinates with
# the other regridded fields it is combined with (xarray aligns on them) — confirm.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='GRD',
)

aligned_GRD_pp_rg = regridded_data_da


In [None]:
# Restrict the regridded GRD field to the shared time steps in `common_times`
GRD_common = aligned_GRD_pp_rg.loc[{'time': common_times}]

### PVE Equation
$$\mathrm{PVE} = \left[1 - \frac{\mathrm{var}(x - y)}{\mathrm{var}(x)}\right] \times 100\%$$

where $x$ is the GRACE-minus-ECCO-v4r5 residual (GRACE-v4r5) and $y$ is GRD, with variances taken over time at each grid point. We compute PVE for both the full fields and the detrended fields, and for both the GRACE-v4r5 and GRACE-ctrl residuals.

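For the GRACE-v4r5 case this reads

$$\mathrm{PVE} = \left[1 - \frac{\mathrm{var}\{(\mathrm{GRACE} - \mathrm{v4r5}) - \mathrm{GRD}\}}{\mathrm{var}(\mathrm{GRACE} - \mathrm{v4r5})}\right] \times 100\%$$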

In [None]:
# PVE of the GRACE-minus-v4r5 residual by GRD, full fields:
# PVE = [1 - var(x - y) / var(x)] * 100, variances taken over time per pixel.
x = GRACE_common - ECCO_pb_v4r5_common   # residual to be explained
y = GRD_common                            # explaining field
z = x - y                                 # what remains after removing GRD

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_raw_std = y.std(dim='time')

PVE_v4r5_raw = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl residual by GRD, full fields:
# PVE = [1 - var(x - y) / var(x)] * 100, variances taken over time per pixel.
x = GRACE_common - ECCO_pb_ctrl_common   # residual to be explained
y = GRD_common                            # explaining field
z = x - y                                 # what remains after removing GRD

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_raw = 100 * (1 - z_var / x_var)

### Compute PVE after detrending data

In [None]:
# Per-pixel detrended versions (detrend_per_pixel, defined earlier) of every
# field entering the detrended PVE calculation; calls run in the listed order.
grace_common_dt, ECCO_pb_v4r5_common_dt, ECCO_pb_ctrl_common_dt, GRD_common_dt = (
    detrend_per_pixel(field)
    for field in (GRACE_common, ECCO_pb_v4r5_common, ECCO_pb_ctrl_common, GRD_common)
)

In [None]:
# PVE of the GRACE-minus-v4r5 residual by GRD, detrended fields:
# PVE = [1 - var(x - y) / var(x)] * 100, variances taken over time per pixel.
x = grace_common_dt - ECCO_pb_v4r5_common_dt   # residual to be explained
y = GRD_common_dt                               # explaining field
z = x - y                                       # what remains after removing GRD

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_dt_std = y.std(dim='time')

PVE_v4r5_dt = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl residual by GRD, detrended fields:
# PVE = [1 - var(x - y) / var(x)] * 100, variances taken over time per pixel.
x = grace_common_dt - ECCO_pb_ctrl_common_dt   # residual to be explained
y = GRD_common_dt                               # explaining field
z = x - y                                       # what remains after removing GRD

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_dt = 100 * (1 - z_var / x_var)

In [None]:
# basic check plot
# PVE by GRD: 2x2 maps (rows: full / detrended fields; columns: GRACE-v4r5 / GRACE-ctrl).

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

n_bins = 10  # 0..100 in steps of 10
boundaries = np.arange(0, 110, 10)  # [0, 10, ..., 100]

# Discrete jet for non-negative PVE
# (plt.get_cmap: matplotlib.cm.get_cmap / plt.cm.get_cmap was removed in Matplotlib 3.9)
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])

# Negative PVE (removing GRD increases the variance) is drawn gray
cmap_disc.set_under("#d9d9d9")   # or "#bdbdbd"

# BoundaryNorm maps data to the discrete bins;
# values < boundaries[0] (i.e., <0) use the "under" color
norm_disc = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)

title_list = ['PVE by GRD (GRACE-v4r5) (Full Fields) ',
              'PVE by GRD (GRACE-ctrl) (Full Fields)',
              'PVE by GRD (GRACE-v4r5) (Detrended Fields) ',
              'PVE by GRD (GRACE-ctrl) (Detrended Fields)']

labels = ['(a)', '(b)', '(c)', '(d)']

data_list = [PVE_v4r5_raw, PVE_ctrl_raw,
             PVE_v4r5_dt, PVE_ctrl_dt]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 5.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Lon labels on the bottom row, lat labels on the left column.
    # draw_labels=True is required for any labels to be drawn; the Cartopy >= 0.18
    # names are used instead of the deprecated xlabels_*/ylabels_* attributes.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    gl.bottom_labels = i >= 2
    gl.left_labels = i % 2 == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)

    # Subplot labels (a)-(d)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Lay out the map panels first, then place the colorbar in the reserved bottom strip
plt.tight_layout(rect=[0, 0.1, 1, 1])

# extend='min' adds the gray "under" swatch used for negative PVE
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.02])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries, extend='min')
cbar.ax.set_xticklabels([f'{b:.0f}' for b in boundaries])
cbar.set_label("PVE (%)")

plt.savefig("PVE_GRD.png", dpi=500, bbox_inches='tight')

plt.show()


In [None]:
# basic check plot
# Top row: GRD std (full / detrended fields, cm). Bottom row: PVE by GRD for
# GRACE-v4r5 (full / detrended fields, %). The rows have different units, so
# each row gets its own norm and colorbar.

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

n_bins = 10

# Discrete jet shared by both rows
# (plt.get_cmap: matplotlib.cm.get_cmap / plt.cm.get_cmap was removed in Matplotlib 3.9)
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])

# Anything below the lowest boundary (negative PVE) becomes gray
cmap_disc.set_under("#d9d9d9")   # or "#bdbdbd"

# PVE norm: 0..100 % in steps of 10
boundaries = np.arange(0, 110, 10)  # [0, 10, ..., 100]
norm_disc = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)

# GRD std norm: 0..5 cm in steps of 0.5 (10 bins, same colours)
std_boundaries = np.arange(0, 5.0 + 0.5, 0.5)
norm_std = BoundaryNorm(std_boundaries, ncolors=cmap_disc.N, clip=False)

# Titles describe the plotted fields (panels a-b are GRD std, not PVE)
title_list = ['GRD Std. (Full Fields)',
              'GRD Std. (Detrended Fields)',
              'PVE by GRD (GRACE-v4r5) (Full Fields)',
              'PVE by GRD (GRACE-v4r5) (Detrended Fields)']

labels = ['(a)', '(b)', '(c)', '(d)']

data_list = [GRD_raw_std,
             GRD_dt_std,
             PVE_v4r5_raw,
             PVE_v4r5_dt]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 5.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

pc_std = None   # last QuadMesh of the std row
pc_pve = None   # last QuadMesh of the PVE row

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    is_std_panel = i < 2
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_std if is_std_panel else norm_disc)
    if is_std_panel:
        pc_std = pc
    else:
        pc_pve = pc

    ax.set_global()

    # Lon labels on the bottom row, lat labels on the left column.
    # draw_labels=True is required for any labels to be drawn; the Cartopy >= 0.18
    # names are used instead of the deprecated xlabels_*/ylabels_* attributes.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    gl.bottom_labels = i >= 2
    gl.left_labels = i % 2 == 0
    gl.right_labels = False
    gl.top_labels = False

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)

    # Subplot labels (a)-(d)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Lay out the map panels first, then place the colorbars in the reserved bottom strip
plt.tight_layout(rect=[0, 0.1, 1, 1])

# Std colorbar (left half); extend='max' marks values above 5 cm
cbar_std_ax = fig.add_axes([0.1, 0.08, 0.38, 0.02])
cbar_std = fig.colorbar(pc_std, cax=cbar_std_ax, orientation='horizontal',
                        ticks=std_boundaries, extend='max')
cbar_std.ax.set_xticklabels([f'{b:.1f}' for b in std_boundaries])
cbar_std.set_label("GRD Std. (cm)")

# PVE colorbar (right half); extend='min' shows the gray swatch for negative PVE
cbar_pve_ax = fig.add_axes([0.52, 0.08, 0.38, 0.02])
cbar_pve = fig.colorbar(pc_pve, cax=cbar_pve_ax, orientation='horizontal',
                        ticks=boundaries, extend='min')
cbar_pve.ax.set_xticklabels([f'{b:.0f}' for b in boundaries])
cbar_pve.set_label("PVE (%)")

# Check plot only: not saved to file
plt.show()


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, Normalize

# 2x2 maps: GRD std (full / detrended) and PVE by GRD for GRACE-v4r5
# (full / detrended). Each panel has its own limits and its own colorbar.

# --- colormap (discrete jet for >=0; gray for <0) ---
# plt.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was
# removed in Matplotlib 3.9.
n_colors = 10
jet = plt.get_cmap("jet", n_colors)
cmap_disc = ListedColormap([jet(i) for i in range(n_colors)])
cmap_disc.set_under("#d9d9d9")

# per-panel (vmin, vmax)
vlims_list = [(0, 5), (0, 0.7), (0, 100), (0, 100)]

title_list = [
    'GRD Std.',
    'GRD std. (no trend)',
    'PVE by GRD (GRACE-v4r5) (Full Fields)',
    'PVE by GRD (GRACE-v4r5) (Detrended Fields)'
]
labels = ['(a)', '(b)', '(c)', '(d)']
data_list = [GRD_raw_std, GRD_dt_std, PVE_v4r5_raw, PVE_v4r5_dt]

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(9, 6.5), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

for i, (ax, data, title, (vmin, vmax)) in enumerate(zip(axes.flat, data_list, title_list, vlims_list)):

    # --- build per-panel norm/ticks safely ---
    if np.isclose(vmax, vmin):
        # Degenerate range: continuous Normalize over a tiny positive span, 3 ticks
        norm = Normalize(vmin=vmin, vmax=vmax if vmax > vmin else vmin + 1e-6)
        ticks = np.linspace(vmin, vmax if vmax > vmin else vmin + 1e-6, 3)
    else:
        nbins = n_colors  # number of discrete bins per panel
        boundaries = np.linspace(vmin, vmax, nbins + 1, dtype=float)
        norm = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)
        # a manageable number of colorbar ticks
        nticks = min(6, nbins + 1)
        ticks = np.linspace(vmin, vmax, nticks)

    # --- plot ---
    pc = ax.pcolormesh(
        data.longitude, data.latitude, data,
        transform=ccrs.PlateCarree(), cmap=cmap_disc, norm=norm
    )

    ax.set_global()

    # Gridline labels: bottom row only; left column only. Old (xlabels_*/ylabels_*)
    # and new (*_labels) Cartopy names are each set only when present, so this
    # works across Cartopy versions.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i % 2 == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i % 2 == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False

    gl.xlabel_style = {'size': 7}      # or {'fontsize': 7}
    gl.ylabel_style = {'size': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m', linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # Per-panel colorbar below the map; PVE panels use extend="min" so the
    # gray "under" swatch for negative PVE is shown
    cax = ax.inset_axes([0.08, -0.20, 0.84, 0.06])
    if i in (0, 1):
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks)
    else:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks, extend="min")
    cb.ax.tick_params(labelsize=7)
    if i in (0, 1):
        cb.set_label("cm", fontsize=8)
    else:
        cb.set_label("PVE (%)", fontsize=8)

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig("PVE_GRD_2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Trend of GRD over the common time steps (compute_trends_manually, defined earlier).
# NOTE(review): the *365 conversion presumably turns a per-day slope into a
# per-year one — confirm the units returned by compute_trends_manually.
GRD_common = aligned_GRD_pp_rg.sel(time=common_times)
GRD_common_trend = compute_trends_manually(GRD_common)
GRD_common_trend_pyr = 365 * GRD_common_trend

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, Normalize

# 2x2 maps: GRD trend, detrended GRD std, and PVE by GRD for GRACE-v4r5
# (full / detrended). Each panel has its own limits and its own colorbar.

# --- colormaps ---
# plt.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was
# removed in Matplotlib 3.9.
# Discrete jet for the non-negative panels; values below vmin are drawn gray
n_colors = 10
jet = plt.get_cmap("jet", n_colors)
cmap_disc = ListedColormap([jet(i) for i in range(n_colors)])
cmap_disc.set_under("#d9d9d9")

# Discrete coolwarm for the signed trend panel (sampled with its own n_colors2)
n_colors2 = 10
cw = plt.get_cmap("coolwarm", n_colors2)
cmap_disc2 = ListedColormap([cw(i) for i in range(n_colors2)])

# per-panel (vmin, vmax)
vlims_list = [(-0.7, 0.7), (0, 1), (0, 100), (0, 100)]

title_list = [
    'GRD Trend',
    'GRD std. (no trend)',
    'PVE by GRD (GRACE-v4r5) (Full Fields)',
    'PVE by GRD (GRACE-v4r5) (Detrended Fields)'
]
labels = ['(a)', '(b)', '(c)', '(d)']
data_list = [GRD_common_trend_pyr, GRD_dt_std, PVE_v4r5_raw, PVE_v4r5_dt]
cbar_label_list = ["cm/year", "cm", "PVE (%)", "PVE (%)"]

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(9, 6.5), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

for i, (ax, data, title, (vmin, vmax)) in enumerate(zip(axes.flat, data_list, title_list, vlims_list)):

    # Panel (a) is the signed trend -> diverging colormap; the others use jet
    panel_cmap = cmap_disc2 if i == 0 else cmap_disc

    # --- build per-panel norm/ticks safely ---
    if np.isclose(vmax, vmin):
        # Degenerate range: continuous Normalize over a tiny positive span, 3 ticks
        norm = Normalize(vmin=vmin, vmax=vmax if vmax > vmin else vmin + 1e-6)
        ticks = np.linspace(vmin, vmax if vmax > vmin else vmin + 1e-6, 3)
    else:
        nbins = n_colors  # number of discrete bins per panel
        boundaries = np.linspace(vmin, vmax, nbins + 1, dtype=float)
        # ncolors must match the colormap actually used by this panel
        norm = BoundaryNorm(boundaries, ncolors=panel_cmap.N, clip=False)
        # a manageable number of colorbar ticks
        nticks = min(6, nbins + 1)
        ticks = np.linspace(vmin, vmax, nticks)

    # --- plot ---
    # The trend panel takes its coordinates from GRD_common, as before.
    # Presumably this is because compute_trends_manually's output may not carry
    # latitude/longitude coords — TODO confirm.
    if i == 0:
        plot_lon, plot_lat = GRD_common.longitude, GRD_common.latitude
    else:
        plot_lon, plot_lat = data.longitude, data.latitude
    pc = ax.pcolormesh(
        plot_lon, plot_lat, data,
        transform=ccrs.PlateCarree(), cmap=panel_cmap, norm=norm
    )

    ax.set_global()

    # Gridline labels: bottom row only; left column only. Old (xlabels_*/ylabels_*)
    # and new (*_labels) Cartopy names are each set only when present, so this
    # works across Cartopy versions.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i % 2 == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i % 2 == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False

    gl.xlabel_style = {'size': 7}      # or {'fontsize': 7}
    gl.ylabel_style = {'size': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m', linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # Per-panel colorbar below the map; PVE panels use extend="min" so the
    # gray "under" swatch for negative PVE is shown
    cax = ax.inset_axes([0.08, -0.20, 0.84, 0.06])
    if i in (0, 1):
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks)
    else:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks, extend="min")
    cb.ax.tick_params(labelsize=7)
    cb.set_label(cbar_label_list[i], fontsize=8)

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig("PVE_GRD_3.png", dpi=500, bbox_inches='tight')
plt.show()


### UTA CSR GRD EQ Correction Analysis

In [None]:
# Open the UTA CSR GRD and EQ datasets (opened in this order: GRD first, then EQ).
CSR_GRD_aligned, CSR_EQ_aligned = (
    xr.open_dataset(path)
    for path in ('/glade/work/netige/Data/GRACE/UTA_CSR_GRD_aligned.nc',
                 '/glade/work/netige/Data/GRACE/UTA_CSR_EQ_aligned.nc')
)


In [None]:
# Convert the CSR GRD and EQ time axes from decimal years to month-start datetimes.

def _decimal_year_to_month_start(time_values):
    """Convert decimal-year values to month-start timestamps.

    The month number is round(fractional_part * 12 + 1), using numpy's
    round-half-to-even rule. Year and month are combined as an integer month
    count, so a value that rounds to "month 13" rolls over to January of the
    next year. No date strings are built from floats: astype(str) on a
    float month gave e.g. '2002-4.0-01'.

    Parameters
    ----------
    time_values : array-like
        Decimal years, or an axis that is already datetime64 (returned unchanged).

    Returns
    -------
    pandas.DatetimeIndex
        One month-start timestamp per input value.
    """
    values = np.asarray(time_values)
    if np.issubdtype(values.dtype, np.datetime64):
        # Already converted (e.g. cell re-run)
        return pd.DatetimeIndex(values)
    values = values.astype(float)
    base_year = np.floor(values).astype(int)                          # integer year part
    month_number = np.round((values - base_year) * 12 + 1).astype(int)  # 1..13
    # Months since 1970-01; integers cast to datetime64[M] count from the epoch.
    months_since_epoch = (base_year - 1970) * 12 + (month_number - 1)
    return pd.DatetimeIndex(
        months_since_epoch.astype('datetime64[M]').astype('datetime64[ns]')
    )


dates = _decimal_year_to_month_start(CSR_GRD_aligned.time)

CSR_GRD_aligned = CSR_GRD_aligned.assign_coords(time=dates)
# EQ is converted from its own time values rather than assuming it shares GRD's axis
CSR_EQ_aligned = CSR_EQ_aligned.assign_coords(
    time=_decimal_year_to_month_start(CSR_EQ_aligned.time)
)


In [None]:
# Regrid the CSR GRD field from the ECCO LLC90 grid onto a global 0.5-degree
# latitude/longitude grid, one time step at a time.

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Source field and the ECCO grid it is defined on
data = CSR_GRD_aligned.GRD
grid = eccor5_grid

# Target grid: resolution and global bounds (degrees)
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Surface wet-cell mask: points where the top-level hFacC is 0 become NaN before regridding
surface_is_wet = grid.hFacC.isel(k=0) != 0

# Nearest-neighbour resample of each masked time slice. Element [4] of the
# returned tuple is the regridded 2-D field; the other four values are discarded.
global_data_list = [
    ecco.resample_to_latlon(
        grid.XC, grid.YC, data.isel(time=step).where(surface_is_wet),
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan,
    )[4]
    for step in range(data.sizes['time'])
]

# Stack into a (time, lat, lon) array
global_data_array = np.stack(global_data_list, axis=0)

# NOTE(review): evenly spaced labels from the grid edges, offset from the true
# cell centres by up to half a cell. Presumably kept identical to the other
# regridded products so xarray alignment works — confirm.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='GRD',
)

aligned_CSR_GRD_pp_rg = regridded_data_da
# Free the source dataset once regridded
del CSR_GRD_aligned

In [None]:
# CSR EQ put into common grid: regrid the EQ field from the ECCO LLC90 grid onto
# a global 0.5-degree latitude/longitude grid, one time step at a time.

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Source field and the ECCO grid it is defined on
data = CSR_EQ_aligned.EQ
grid = eccor5_grid

# Target grid: resolution and global bounds (degrees)
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Surface wet-cell mask: points where the top-level hFacC is 0 become NaN before regridding
surface_is_wet = grid.hFacC.isel(k=0) != 0

# Nearest-neighbour resample of each masked time slice. Element [4] of the
# returned tuple is the regridded 2-D field; the other four values are discarded.
global_data_list = [
    ecco.resample_to_latlon(
        grid.XC, grid.YC, data.isel(time=step).where(surface_is_wet),
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan,
    )[4]
    for step in range(data.sizes['time'])
]

# Stack into a (time, lat, lon) array
global_data_array = np.stack(global_data_list, axis=0)

# NOTE(review): evenly spaced labels from the grid edges, offset from the true
# cell centres by up to half a cell. Presumably kept identical to the other
# regridded products so xarray alignment works — confirm.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])  # 360 points
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])  # 720 points

# Named 'EQ' (previously mislabelled 'GRD', copied from the GRD cell)
regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='EQ',
)

aligned_CSR_EQ_pp_rg = regridded_data_da
# Free the source dataset once regridded
del CSR_EQ_aligned

In [None]:
# Months present in both the regridded CSR GRD product and the regridded GRACE record
common_times2 = np.intersect1d(
    aligned_CSR_GRD_pp_rg.time.values,
    aligned_GRACE_pp_rg.time.values,
)

# Restrict every field to those months. (The "CRS_" prefix on the ECCO names is
# kept because later cells use these names.)
(CSR_GRD_common, CSR_EQ_common, CSR_GRACE_common,
 CRS_ECCO_v4r5_common, CRS_ECCO_ctrl_common) = (
    field.sel(time=common_times2)
    for field in (aligned_CSR_GRD_pp_rg, aligned_CSR_EQ_pp_rg, aligned_GRACE_pp_rg,
                  ECCO_pb_v4r5_pp_rg, ECCO_pb_ctrl_pp_rg)
)

In [None]:
# Temporal std of the CSR GRD and EQ fields over the common months
CSR_GRD_common_sd = CSR_GRD_common.std(dim='time')
CSR_EQ_common_sd = CSR_EQ_common.std(dim='time')

# Per-pixel detrended fields (detrend_per_pixel, defined earlier in the notebook)
CSR_GRD_common_dt = detrend_per_pixel(CSR_GRD_common)
CSR_GRD_common_dt_sd = CSR_GRD_common_dt.std(dim='time')
CRS_ECCO_v4r5_common_dt = detrend_per_pixel(CRS_ECCO_v4r5_common)
CRS_ECCO_ctrl_common_dt = detrend_per_pixel(CRS_ECCO_ctrl_common)
CSR_GRACE_common_dt = detrend_per_pixel(CSR_GRACE_common)
CSR_EQ_common_dt = detrend_per_pixel(CSR_EQ_common)

# GRD trend; NOTE(review): *365 presumably converts a per-day slope to per year — confirm
CSR_GRD_common_trend = compute_trends_manually(CSR_GRD_common)
CSR_GRD_common_trend_pyr = 365 * CSR_GRD_common_trend

### Plot the std of EQ and GRD for CSR products

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, Normalize

# Four-panel figure of the CSR corrections over the common period:
# (a) GRD trend, (b) GRD std. with trend, (c) GRD std. after detrending,
# (d) EQ std. Each panel has its own discrete colour scale and colourbar.

# --- colormaps ---
# plt.get_cmap is used because plt.cm.get_cmap was removed in Matplotlib 3.9.
# Discrete jet for the std. panels (b)-(d).
n_colors = 10
jet = plt.get_cmap("jet", n_colors)
cmap_disc = ListedColormap([jet(i) for i in range(n_colors)])

# Discrete diverging coolwarm for the trend panel (a).
n_colors2 = 10
cw = plt.get_cmap("coolwarm", n_colors2)
cmap_disc2 = ListedColormap([cw(i) for i in range(n_colors2)])

# per-panel colour limits
vlims_list = [(-0.7, 0.7), (0, 5), (0, 1), (0, 2)]

title_list = [
    'CSR GRD Trend',
    'CSR GRD std. (with trend)',
    'CSR GRD std. (no trend)',
    'CSR EQ std.'
]
labels = ['(a)', '(b)', '(c)', '(d)']
data_list = [CSR_GRD_common_trend_pyr, CSR_GRD_common_sd, CSR_GRD_common_dt_sd, CSR_EQ_common_sd]
# Colourbar units per panel: the trend is a rate, all std. fields are in cm.
cbar_units = ["cm/year", "cm", "cm", "cm"]

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(9, 6.5), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

for i, (ax, data, title, (vmin, vmax)) in enumerate(zip(axes.flat, data_list, title_list, vlims_list)):
    cmap = cmap_disc2 if i == 0 else cmap_disc

    # --- per-panel norm/ticks; the bin count follows the panel's own colormap ---
    if np.isclose(vmax, vmin):
        # Degenerate range: continuous Normalize with a tiny positive span.
        upper = vmax if vmax > vmin else vmin + 1e-6
        norm = Normalize(vmin=vmin, vmax=upper)
        ticks = np.linspace(vmin, upper, 3)
    else:
        nbins = cmap.N
        boundaries = np.linspace(vmin, vmax, nbins + 1, dtype=float)
        norm = BoundaryNorm(boundaries, ncolors=cmap.N, clip=False)
        # keep the colourbar readable: at most 6 ticks
        ticks = np.linspace(vmin, vmax, min(6, nbins + 1))

    # The trend panel takes its lon/lat from CSR_GRD_common, as the original
    # cell did (presumably the trend output lacks these coords — confirm).
    coord_src = CSR_GRD_common if i == 0 else data
    pc = ax.pcolormesh(
        coord_src.longitude, coord_src.latitude, data,
        transform=ccrs.PlateCarree(), cmap=cmap, norm=norm
    )

    ax.set_global()

    # Gridline labels: bottom row only; left column only. Both the old
    # (xlabels_*) and new (*_labels) cartopy attribute names are set when present.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i % 2 == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i % 2 == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False

    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m', linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # per-panel colourbar below the map
    cax = ax.inset_axes([0.08, -0.20, 0.84, 0.06])
    cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks)
    cb.ax.tick_params(labelsize=7)
    cb.set_label(cbar_units[i], fontsize=8)

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig("CSR_GRD_Trend_Std_EQ_std.png", dpi=500, bbox_inches='tight')
plt.show()


### Check the influence of CSR GRD on GRACE/ECCO misfits

In [None]:
# Percentage of the GRACE-minus-v4r5 misfit variance explained by the CSR
# GRD correction (full fields):  PVE = 100 * (1 - var(x - y) / var(x)).
x = CSR_GRACE_common - CRS_ECCO_v4r5_common   # misfit
y = CSR_GRD_common                            # GRD correction
z = x - y                                     # misfit with GRD removed

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_raw_std = y.std(dim='time')

PVE_v4r5_raw = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl misfit by the CSR GRD correction (full fields).
x = CSR_GRACE_common - CRS_ECCO_ctrl_common   # misfit
y = CSR_GRD_common                            # GRD correction
z = x - y                                     # misfit with GRD removed

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_raw = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-v4r5 misfit by the CSR GRD correction, using the
# detrended fields.
x = CSR_GRACE_common_dt - CRS_ECCO_v4r5_common_dt   # detrended misfit
y = CSR_GRD_common_dt                               # detrended GRD
z = x - y

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_dt_std = y.std(dim='time')

PVE_v4r5_dt = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl misfit by the CSR GRD correction, using the
# detrended fields.
x = CSR_GRACE_common_dt - CRS_ECCO_ctrl_common_dt   # detrended misfit
y = CSR_GRD_common_dt                               # detrended GRD
z = x - y

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_dt = 100 * (1 - z_var / x_var)

In [None]:
# basic check plot: percentage of the GRACE-ECCO misfit variance explained
# (PVE) by the CSR GRD correction, full and detrended fields.

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

n_bins = 10  # 0..100 in steps of 10
boundaries = np.arange(0, 110, 10)  # [0, 10, ..., 100]

# Discrete jet for nonnegative values (plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9).
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])

# Anything < 0 (removing the correction increased the misfit variance) is gray.
cmap_disc.set_under("#d9d9d9")

# BoundaryNorm maps data to the discrete bins;
# values < boundaries[0] (i.e. < 0) use the "under" colour.
norm_disc = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)

title_list = ['PVE by CSR GRD (GRACE-v4r5) (Full Fields) ',
              'PVE by CSR GRD (GRACE-ctrl) (Full Fields)',
              'PVE by CSR GRD (GRACE-v4r5) (Detrended Fields) ',
              'PVE by CSR GRD (GRACE-ctrl) (Detrended Fields)']

labels = ['(a)', '(b)', '(c)', '(d)']

data_list = [PVE_v4r5_raw, PVE_ctrl_raw,
             PVE_v4r5_dt, PVE_ctrl_dt]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 5.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridline labels: longitudes on the bottom row, latitudes on the left
    # column. draw_labels=True is required for these flags to have any
    # effect; old (xlabels_*) and new (*_labels) cartopy names are both set
    # when the attribute exists.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i % 2 == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i % 2 == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False
    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)

    # Subplot labels (a)-(d)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared colourbar; extend='min' adds a triangle showing the gray used for
# negative PVE.
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.02])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries, extend='min')
cbar.ax.set_xticklabels([f'{b:.0f}' for b in boundaries])
cbar.set_label("PVE (%)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("PVE_CSR_GRD.png", dpi=500, bbox_inches='tight')

plt.show()


### Check the influence of CSR EQ on GRACE/ECCO misfits

In [None]:
# Percentage of the GRACE-minus-v4r5 misfit variance explained by the CSR
# EQ correction (full fields):  PVE = 100 * (1 - var(x - y) / var(x)).
# NOTE(review): this reuses the names GRD_raw_std and PVE_v4r5_raw from the
# GRD section, overwriting them — GRD_raw_std now holds the EQ std.
x = CSR_GRACE_common - CRS_ECCO_v4r5_common   # misfit
y = CSR_EQ_common                             # EQ correction
z = x - y                                     # misfit with EQ removed

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_raw_std = y.std(dim='time')

PVE_v4r5_raw = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl misfit by the CSR EQ correction (full fields).
# Overwrites the GRD-based PVE_ctrl_raw computed earlier.
x = CSR_GRACE_common - CRS_ECCO_ctrl_common   # misfit
y = CSR_EQ_common                             # EQ correction
z = x - y                                     # misfit with EQ removed

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_raw = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-v4r5 misfit by the CSR EQ correction, using the
# detrended fields.
# NOTE(review): GRD_dt_std is overwritten here with the detrended EQ std.
x = CSR_GRACE_common_dt - CRS_ECCO_v4r5_common_dt   # detrended misfit
y = CSR_EQ_common_dt                                # detrended EQ
z = x - y

x_var = x.var(dim='time')
z_var = z.var(dim='time')

GRD_dt_std = y.std(dim='time')

PVE_v4r5_dt = 100 * (1 - z_var / x_var)

In [None]:
# PVE of the GRACE-minus-ctrl misfit by the CSR EQ correction, using the
# detrended fields.
x = CSR_GRACE_common_dt - CRS_ECCO_ctrl_common_dt   # detrended misfit
y = CSR_EQ_common_dt                                # detrended EQ
z = x - y

x_var = x.var(dim='time')
z_var = z.var(dim='time')

PVE_ctrl_dt = 100 * (1 - z_var / x_var)

In [None]:
# basic check plot: percentage of the GRACE-ECCO misfit variance explained
# (PVE) by the CSR EQ correction, full and detrended fields.

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

n_bins = 10  # 0..100 in steps of 10
boundaries = np.arange(0, 110, 10)  # [0, 10, ..., 100]

# Discrete jet for nonnegative values (plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9).
jet = plt.get_cmap("jet", n_bins)
cmap_disc = ListedColormap([jet(i) for i in range(n_bins)])

# Anything < 0 (removing the correction increased the misfit variance) is gray.
cmap_disc.set_under("#d9d9d9")

# BoundaryNorm maps data to the discrete bins;
# values < boundaries[0] (i.e. < 0) use the "under" colour.
norm_disc = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)

title_list = ['PVE by EQ (GRACE-v4r5) (Full Fields) ',
              'PVE by EQ (GRACE-ctrl) (Full Fields)',
              'PVE by EQ (GRACE-v4r5) (Detrended Fields) ',
              'PVE by EQ (GRACE-ctrl) (Detrended Fields)']

labels = ['(a)', '(b)', '(c)', '(d)']

data_list = [PVE_v4r5_raw, PVE_ctrl_raw,
             PVE_v4r5_dt, PVE_ctrl_dt]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 5.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(data.longitude,
                       data.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap=cmap_disc,
                       norm=norm_disc)

    ax.set_global()

    # Gridline labels: longitudes on the bottom row, latitudes on the left
    # column. draw_labels=True is required for these flags to have any
    # effect; old (xlabels_*) and new (*_labels) cartopy names are both set
    # when the attribute exists.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i % 2 == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i % 2 == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False
    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    # Add land feature
    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m',
        linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)

    # Subplot labels (a)-(d)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared colourbar; extend='min' adds a triangle showing the gray used for
# negative PVE.
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.02])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal', ticks=boundaries, extend='min')
cbar.ax.set_xticklabels([f'{b:.0f}' for b in boundaries])
cbar.set_label("PVE (%)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.savefig("PVE_CSR_EQ.png", dpi=500, bbox_inches='tight')

plt.show()


In [None]:
# Regional subsets of the CSR EQ std. map and of the EQ-based PVE
# (GRACE-v4r5, full fields) around the two large earthquakes.
# NOTE: PVE_v4r5_raw here is the EQ-based PVE from the cells above (it
# overwrote the GRD-based values of the same name). The label slices assume
# ascending latitude/longitude coordinates.

# Japan / NW Pacific: 120°E–170°E, 25°N–55°N
japan_box = dict(longitude=slice(120, 170), latitude=slice(25, 55))
# Eastern Indian Ocean / Sumatra: 70°E–140°E, 15°S–25°N
indian_box = dict(longitude=slice(70, 140), latitude=slice(-15, 25))

CSR_EQ_japan  = CSR_EQ_common_sd.sel(**japan_box)
PVE_japan     = PVE_v4r5_raw.sel(**japan_box)

CSR_EQ_indian = CSR_EQ_common_sd.sel(**indian_box)
PVE_indian    = PVE_v4r5_raw.sel(**indian_box)


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, Normalize

# Regional close-ups: CSR EQ std. (top row) and the EQ-based PVE for
# GRACE-v4r5 full fields (bottom row); left column is the Japan box, right
# column the Indian Ocean / Sumatra box.

# --- colormap: discrete jet; values below a panel's lower limit are gray ---
# (plt.get_cmap replaces plt.cm.get_cmap, removed in Matplotlib 3.9)
n_colors = 30
jet = plt.get_cmap("jet", n_colors)
cmap_disc = ListedColormap([jet(i) for i in range(n_colors)])
cmap_disc.set_under("#d9d9d9")

# per-panel limits
vlims_list = [(0, 20), (0, 20), (0, 100), (0, 100)]

title_list = [
    'CSR EQ Std.',
    'CSR EQ Std.',
    'PVE by EQ (GRACE-v4r5) (Full Fields)',
    'PVE by EQ (GRACE-v4r5) (Full Fields)'
]
labels = ['(a)', '(b)', '(c)', '(d)']
data_list = [CSR_EQ_japan, CSR_EQ_indian, PVE_japan, PVE_indian]

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(9, 8), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

for i, (ax, data, title, (vmin, vmax)) in enumerate(zip(axes.flat, data_list, title_list, vlims_list)):
    is_pve = i >= 2   # bottom row holds the PVE maps

    # --- per-panel norm/ticks ---
    if np.isclose(vmax, vmin):
        # Degenerate range: continuous Normalize with a tiny positive span.
        upper = vmax if vmax > vmin else vmin + 1e-6
        norm = Normalize(vmin=vmin, vmax=upper)
        ticks = np.linspace(vmin, upper, 3)
    else:
        nbins = n_colors
        boundaries = np.linspace(vmin, vmax, nbins + 1, dtype=float)
        norm = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)
        # keep the colourbar readable: at most 6 ticks
        ticks = np.linspace(vmin, vmax, min(6, nbins + 1))

    pc = ax.pcolormesh(
        data.longitude, data.latitude, data,
        transform=ccrs.PlateCarree(), cmap=cmap_disc, norm=norm
    )
    # Regional panels: the extent follows the data (no ax.set_global()).

    # Longitude labels on the bottom row only; latitude labels on every panel
    # (the two regions have different latitude ranges).
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = (i >= 2)
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = True
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = (i >= 2)
    if hasattr(gl, "left_labels"):    gl.left_labels    = True
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False

    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m', linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # per-panel colourbar below the map; PVE panels can be negative, so
    # their colourbar shows the gray "under" colour.
    cax = ax.inset_axes([0.08, -0.20, 0.84, 0.06])
    if is_pve:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks, extend="min")
        cb.ax.tick_params(labelsize=7)
        cb.set_label("PVE (%)", fontsize=8)
    else:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks)
        cb.ax.tick_params(labelsize=7)
        cb.set_label("cm", fontsize=8)

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig("PVE_EQ_regional.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Single region covering both earthquakes (Sumatra and Tōhoku):
# 70°E–170°E, 15°S–55°N. Subsets the CSR EQ std. map and the EQ-based PVE
# (GRACE-v4r5, full fields); slices assume ascending lat/lon coordinates.
eh_box = dict(longitude=slice(70, 170), latitude=slice(-15, 55))

CSR_EQ_eh  = CSR_EQ_common_sd.sel(**eh_box)
PVE_eh     = PVE_v4r5_raw.sel(**eh_box)


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm, Normalize

# Two-panel close-up of the combined region: (a) CSR EQ std.,
# (b) EQ-based PVE for GRACE-v4r5 full fields.

# --- colormap: discrete jet; values below a panel's lower limit are gray ---
# (plt.get_cmap replaces plt.cm.get_cmap, removed in Matplotlib 3.9)
n_colors = 30
jet = plt.get_cmap("jet", n_colors)
cmap_disc = ListedColormap([jet(i) for i in range(n_colors)])
cmap_disc.set_under("#d9d9d9")

# per-panel limits (one entry per panel)
vlims_list = [(0, 20), (0, 100)]

title_list = [
    'CSR EQ Std.',
    'PVE by EQ (GRACE-v4r5) (Full Fields)'
]
labels = ['(a)', '(b)']
data_list = [CSR_EQ_eh, PVE_eh]

fig, axes = plt.subplots(
    nrows=1, ncols=2, figsize=(9, 8), dpi=500,
    subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)}
)

for i, (ax, data, title, (vmin, vmax)) in enumerate(zip(axes.flat, data_list, title_list, vlims_list)):
    is_pve = (i == 1)

    # --- per-panel norm/ticks ---
    if np.isclose(vmax, vmin):
        # Degenerate range: continuous Normalize with a tiny positive span.
        upper = vmax if vmax > vmin else vmin + 1e-6
        norm = Normalize(vmin=vmin, vmax=upper)
        ticks = np.linspace(vmin, upper, 3)
    else:
        nbins = n_colors
        boundaries = np.linspace(vmin, vmax, nbins + 1, dtype=float)
        norm = BoundaryNorm(boundaries, ncolors=cmap_disc.N, clip=False)
        # keep the colourbar readable: at most 6 ticks
        ticks = np.linspace(vmin, vmax, min(6, nbins + 1))

    pc = ax.pcolormesh(
        data.longitude, data.latitude, data,
        transform=ccrs.PlateCarree(), cmap=cmap_disc, norm=norm
    )
    # Regional panels: the extent follows the data (no ax.set_global()).

    # Longitude labels on both panels; latitude labels on the left panel only.
    gl = ax.gridlines(color='gray', linestyle='dashed', linewidth=0.01, draw_labels=True)
    if hasattr(gl, "xlabels_bottom"): gl.xlabels_bottom = True
    if hasattr(gl, "ylabels_left"):   gl.ylabels_left   = (i == 0)
    if hasattr(gl, "xlabels_top"):    gl.xlabels_top    = False
    if hasattr(gl, "ylabels_right"):  gl.ylabels_right  = False
    if hasattr(gl, "bottom_labels"):  gl.bottom_labels  = True
    if hasattr(gl, "left_labels"):    gl.left_labels    = (i == 0)
    if hasattr(gl, "top_labels"):     gl.top_labels     = False
    if hasattr(gl, "right_labels"):   gl.right_labels   = False

    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    ax.add_feature(cfeature.NaturalEarthFeature(
        'physical', 'land', '110m', linewidth=0.5, edgecolor='black', facecolor='darkgray'
    ))

    ax.set_title(title, fontsize=10)
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

    # per-panel colourbar below the map; the PVE panel can be negative, so its
    # colourbar shows the gray "under" colour (matching the 2x2 regional figure).
    cax = ax.inset_axes([0.08, -0.20, 0.84, 0.06])
    if is_pve:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks, extend="min")
        cb.ax.tick_params(labelsize=7)
        cb.set_label("PVE (%)", fontsize=8)
    else:
        cb = fig.colorbar(pc, cax=cax, orientation='horizontal', ticks=ticks)
        cb.ax.tick_params(labelsize=7)
        cb.set_label("cm", fontsize=8)

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig("PVE_EQ_regional2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter

# CSR EQ correction field, (time, latitude, longitude)
# (assumed lon in [-180, 180]).
da = aligned_CSR_EQ_pp_rg

# Approximate epicentre locations (°N, °E) and their legend labels.
sumatra = {"lat": 3.316, "lon": 95.854, "label": "~2004 Sumatra–Andaman Epicenter"}
tohoku = {"lat": 35.297, "lon": 142.372, "label": "~2011 Tōhoku Epicenter"}

# Set use_box=True to average over a +/- box_deg window around each site
# instead of taking the nearest grid cell.
use_box = False
box_deg = 0.5

def extract_ts(da, lat, lon, use_box=False, box=0.5):
    """Return the time series of `da` at (lat, lon).

    With use_box=False, the grid cell nearest to (lat, lon) is selected.
    With use_box=True, the values inside the +/- `box` degree window around
    (lat, lon) are averaged over latitude and longitude; the label slices
    assume ascending latitude/longitude coordinates.
    """
    if not use_box:
        return da.sel(latitude=lat, longitude=lon, method="nearest")
    window = da.sel(latitude=slice(lat - box, lat + box),
                    longitude=slice(lon - box, lon + box))
    return window.mean(("latitude", "longitude"))

ts_sumatra = extract_ts(da, sumatra["lat"], sumatra["lon"], use_box, box_deg)
ts_tohoku = extract_ts(da, tohoku["lat"], tohoku["lon"], use_box, box_deg)

# One line per site, then dashed vertical markers on the event dates.
fig, ax = plt.subplots(figsize=(8, 3.5), dpi=300)
for series, site in ((ts_sumatra, sumatra), (ts_tohoku, tohoku)):
    ax.plot(series["time"].values, series.values, lw=1.2, label=site["label"])

for event_day, marker_colour in (("2004-12-26", "tab:red"), ("2011-03-11", "tab:purple")):
    ax.axvline(np.datetime64(event_day), color=marker_colour, ls="--", lw=1, alpha=0.8)

ax.set_xlabel("Time")
ax.set_ylabel("CSR EQ Correction (cm)")
ax.legend(frameon=True)
ax.grid(True, ls="--", lw=0.4, alpha=0.5)
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m"))

plt.tight_layout()
plt.savefig("EQ_TS.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Remove the CSR EQ correction from GRACE, then form GRACE-minus-ECCO
# residuals for both ECCO runs.
CSR_GRACE_common_cor = CSR_GRACE_common - CSR_EQ_common
grace_min_v4r5, grace_min_ctrl = (
    CSR_GRACE_common_cor - ecco_run
    for ecco_run in (CRS_ECCO_v4r5_common, CRS_ECCO_ctrl_common)
)

In [None]:
%%time

# EOF analysis (5 leading modes) of the EQ-corrected CSR GRACE field and of
# the GRACE-minus-ECCO residuals. Results are stored per input, in the same
# order as `ds` / `titles`.
# NOTE(review): no area (cos-latitude) weighting is applied to the lat-lon
# grid before decomposition — confirm this is intended.
ds = [CSR_GRACE_common_cor, grace_min_v4r5, grace_min_ctrl]

titles = ["GRACE Raw (no EQ) EOFs",
          "GRACE - v4r5 (no EQ) EOFs",
          "GRACE - ctrl (no EQ) EOFs"]

solvers = []  # kept for compatibility; solver objects are not stored
eofs, pcs, variance_fractions, pcs_norm = [], [], [], []

for i, temp in enumerate(ds):
    # eofs.xarray.Eof expects time as the leading dimension.
    if temp.dims[0] != 'time':
        temp = temp.transpose('time', 'latitude', 'longitude')

    # Keep only dimension coords; missing values (e.g. land) are zero-filled.
    temp = temp.reset_coords(drop=True)
    solver = Eof(temp.fillna(0))

    eof = solver.eofsAsCorrelation(neofs=5)                # EOF patterns (as correlation maps)
    pc = solver.pcs(npcs=5, pcscaling=1)                   # matching PCs
    variance_fraction = solver.varianceFraction(neigs=40)  # fraction of variance per mode
    pc_norm = pc / pc.std(dim='time')                      # PCs rescaled to unit std

    eofs.append(eof)
    pcs.append(pc)
    variance_fractions.append(variance_fraction)
    pcs_norm.append(pc_norm)

    


In [None]:
# Plot the leading EOF maps with their normalised PCs for each input and
# save one figure per input under ./eof_figs/. The directory is created
# first so saving into it does not fail when it is missing (assuming
# plot_eofs_with_pcs, defined elsewhere, does not create it itself).
os.makedirs("./eof_figs", exist_ok=True)

for i, name in enumerate(titles):
    plot_eofs_with_pcs(
        eofs=eofs[i],                              # EOF maps (with lat/lon coords)
        pcs=pcs_norm[i],                           # PCs rescaled to unit std (time x modes)
        title=name,
        n_modes=5,
        variance_fraction=variance_fractions[i],   # optional
        save_path=f"./eof_figs/{name}.png"
    )
    




In [None]:
# Compare the monthly time axes of GRACE_common (the JPL GRACE product,
# defined elsewhere in the notebook) and the CSR GRACE product.
A = GRACE_common.time.to_index()
B = CSR_GRACE_common.time.to_index()

missing_in_A = B.difference(A)   # months CSR has but JPL lacks
missing_in_B = A.difference(B)   # months JPL has but CSR lacks

for heading, missing in (("Missing in JPL GRACE:", missing_in_A),
                         ("Missing in CSR:", missing_in_B)):
    print(heading, missing.tolist())


### Exploring GRAVIS EQ Correction

In [None]:
# Pre-aligned GravIS GRACE dataset; its `pb` variable is regridded below.
GRAVIS_GRACE_aligned = xr.open_dataset(
    '/glade/work/netige/Data/GRACE/pb_grace_gravis_aligned.nc'
)

In [None]:
# Convert the GravIS time axis from decimal years to month-start datetimes.
# Each value t maps to year floor(t) and month rint(12 * frac(t) + 1), i.e.
# the nearest month start. This assumes each month is stamped at (or very
# near) its first day — TODO confirm against the GravIS file metadata.
time_vals = np.asarray(GRAVIS_GRACE_aligned.time, dtype=float)

base_year = np.floor(time_vals).astype(int)    # integer year
fractional_part = time_vals - base_year        # fraction of the year elapsed
mon = np.rint(fractional_part * 12 + 1).astype(int)

# A month outside 1..12 means the timestamps do not follow the assumed
# convention; fail with a clear message instead of an opaque parse error.
if mon.size and (mon.min() < 1 or mon.max() > 12):
    raise ValueError(
        "Decimal-year to month conversion produced months outside 1..12 "
        f"(min={mon.min()}, max={mon.max()}); check the GravIS time convention."
    )

# Build ISO date strings from *integer* months. Formatting the float months
# directly yields strings like "2004-1.0-01", which are not valid dates.
dates = pd.to_datetime(
    [f"{y:04d}-{m:02d}-01" for y, m in zip(base_year, mon)],
    format="%Y-%m-%d",
)

GRAVIS_GRACE_aligned = GRAVIS_GRACE_aligned.assign_coords(time=dates)



In [None]:
# Regrid the GravIS GRACE `pb` field onto a regular 0.5° lat-lon grid, one
# time step at a time. The field is masked with the ECCO grid's hFacC and
# resampled with the ECCO XC/YC, so it presumably lives on the ECCO native
# grid — confirm.

import xarray as xr
import numpy as np
import pandas as pd
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed

# Input data and ECCO grid
data = GRAVIS_GRACE_aligned.pb
grid = eccor5_grid

# Target resolution (degrees)
new_grid_delta_lat = 0.5
new_grid_delta_lon = 0.5

# Global latitude and longitude bounds
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Wet-cell mask from the surface level of hFacC. It does not depend on time,
# so it is computed once instead of on every loop iteration.
ocean_mask = grid.hFacC.isel(k=0) != 0

global_data_list = []
for t in range(data.sizes['time']):
    # Mask dry cells, then resample this time step to the lat-lon grid.
    tmp = data.isel(time=t).where(ocean_mask)
    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp,
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan
    )
    global_data_list.append(regridded_data)

# (time, lat, lon) array
global_data_array = np.stack(global_data_list, axis=0)

# NOTE(review): these evenly spaced coordinates include both endpoints, so
# their spacing is 180/(nlat-1) rather than exactly 0.5° and they are not the
# cell centres returned by resample_to_latlon. They are kept unchanged because
# the other regridded products in this notebook (presumably built the same
# way) must share identical coordinates for xarray arithmetic to align.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])

# The DataArray name 'GRD' is kept as before, although this field is the
# GravIS GRACE `pb` variable.
regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='GRD'
)

# The time coordinate comes straight from the source data (no re-labelling).
aligned_GRAVIS_GRACE_pp_rg = regridded_data_da
# Drop the native-grid dataset to free memory (re-running this cell requires
# re-running the loading cells above).
del GRAVIS_GRACE_aligned

In [None]:
from functools import reduce

# Restrict GravIS GRACE and both ECCO runs to the months present in all
# three. Including the ctrl run in the intersection prevents a KeyError in
# the ctrl .sel() when ctrl lacks a month that GravIS and v4r5 share; when
# it has them all, the result is unchanged.
common_times3 = reduce(
    np.intersect1d,
    [arr.time.values for arr in (aligned_GRAVIS_GRACE_pp_rg,
                                 ECCO_pb_v4r5_pp_rg,
                                 ECCO_pb_ctrl_pp_rg)],
)

GRAVIS_GRACE_common = aligned_GRAVIS_GRACE_pp_rg.sel(time = common_times3)
GRAVIS_ECCO_v4r5_common = ECCO_pb_v4r5_pp_rg.sel(time = common_times3)
GRAVIS_ECCO_ctrl_common = ECCO_pb_ctrl_pp_rg.sel(time = common_times3)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter

# JPL and GravIS GRACE fields, (time, latitude, longitude)
# (assumed lon in [-180, 180]).
da_JPL = aligned_GRACE_pp_rg
da_GRAVIS = aligned_GRAVIS_GRACE_pp_rg

# Approximate epicentre locations (°N, °E) and their legend labels.
sumatra = {"lat": 3.316, "lon": 95.854, "label": "~2004 Sumatra–Andaman Epicenter"}
tohoku = {"lat": 35.297, "lon": 142.372, "label": "~2011 Tōhoku Epicenter"}

# Set use_box=True to average over a +/- box_deg window around each site
# instead of taking the nearest grid cell.
use_box = False
box_deg = 0.5

def extract_ts(da, lat, lon, use_box=False, box=0.5):
    """Return the time series of `da` at (lat, lon).

    With use_box=False, the grid cell nearest to (lat, lon) is selected.
    With use_box=True, the values inside the +/- `box` degree window around
    (lat, lon) are averaged over latitude and longitude; the label slices
    assume ascending latitude/longitude coordinates.
    """
    if not use_box:
        return da.sel(latitude=lat, longitude=lon, method="nearest")
    window = da.sel(latitude=slice(lat - box, lat + box),
                    longitude=slice(lon - box, lon + box))
    return window.mean(("latitude", "longitude"))

# Time series at the two epicentres from both GRACE products.
ts_JPL_sumatra = extract_ts(da_JPL, sumatra["lat"], sumatra["lon"], use_box, box_deg)
ts_GRAVIS_sumatra = extract_ts(da_GRAVIS, sumatra["lat"], sumatra["lon"], use_box, box_deg)

ts_JPL_tohoku = extract_ts(da_JPL, tohoku["lat"], tohoku["lon"], use_box, box_deg)
ts_GRAVIS_tohoku = extract_ts(da_GRAVIS, tohoku["lat"], tohoku["lon"], use_box, box_deg)

# Solid lines = JPL, dashed lines = GravIS; colour identifies the site.
fig, ax = plt.subplots(figsize=(8, 3.5), dpi=300)
for ts, line_style, colour, label in (
    (ts_JPL_sumatra,    "-",  "#2b8cbe", f'JPL {sumatra["label"]}'),
    (ts_JPL_tohoku,     "-",  "#f03b20", f'JPL {tohoku["label"]}'),
    (ts_GRAVIS_sumatra, "--", "#2b8cbe", f'GRAVIS {sumatra["label"]}'),
    (ts_GRAVIS_tohoku,  "--", "#f03b20", f'GRAVIS {tohoku["label"]}'),
):
    ax.plot(ts["time"].values, ts.values, ls=line_style, lw=1.2, color=colour, label=label)

# Event markers
ax.axvline(np.datetime64("2004-12-26"), color="tab:red",   ls="--", lw=1, alpha=0.8)
ax.axvline(np.datetime64("2011-03-11"), color="tab:purple", ls="--", lw=1, alpha=0.8)

ax.set_xlabel("Time")
ax.set_ylabel("OBP (cm)")   # ocean bottom pressure (label previously misspelt "OPB")
ax.legend(frameon=True)
ax.grid(True, ls="--", lw=0.4, alpha=0.5)
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m"))

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter

# Same JPL vs GravIS comparison as the previous cell, sampled at different
# points (assumed lon in [-180, 180]).
da_JPL = aligned_GRACE_pp_rg
da_GRAVIS = aligned_GRAVIS_GRACE_pp_rg

# NOTE(review): these points (6°N 90°E, 34°N 147°E) are several degrees away
# from the epicentres used in the previous cell, yet the legend labels still
# say "Epicenter". They presumably target nearby points of larger signal —
# confirm and relabel.
sumatra = {"lat": 6, "lon": 90, "label": "~2004 Sumatra–Andaman Epicenter"}
tohoku = {"lat": 34, "lon": 147, "label": "~2011 Tōhoku Epicenter"}

# Set use_box=True to average over a +/- box_deg window around each site
# instead of taking the nearest grid cell.
use_box = False
box_deg = 0.5

def extract_ts(da, lat, lon, use_box=False, box=0.5):
    """Return the time series of `da` at (lat, lon).

    With use_box=False, the grid cell nearest to (lat, lon) is selected.
    With use_box=True, the values inside the +/- `box` degree window around
    (lat, lon) are averaged over latitude and longitude; the label slices
    assume ascending latitude/longitude coordinates.
    """
    if not use_box:
        return da.sel(latitude=lat, longitude=lon, method="nearest")
    window = da.sel(latitude=slice(lat - box, lat + box),
                    longitude=slice(lon - box, lon + box))
    return window.mean(("latitude", "longitude"))

# Time series at the two sites from both GRACE products.
ts_JPL_sumatra = extract_ts(da_JPL, sumatra["lat"], sumatra["lon"], use_box, box_deg)
ts_GRAVIS_sumatra = extract_ts(da_GRAVIS, sumatra["lat"], sumatra["lon"], use_box, box_deg)

ts_JPL_tohoku = extract_ts(da_JPL, tohoku["lat"], tohoku["lon"], use_box, box_deg)
ts_GRAVIS_tohoku = extract_ts(da_GRAVIS, tohoku["lat"], tohoku["lon"], use_box, box_deg)

# Solid lines = JPL, dashed lines = GravIS; colour identifies the site.
fig, ax = plt.subplots(figsize=(8, 3.5), dpi=300)
for ts, line_style, colour, label in (
    (ts_JPL_sumatra,    "-",  "#2b8cbe", f'JPL {sumatra["label"]}'),
    (ts_JPL_tohoku,     "-",  "#f03b20", f'JPL {tohoku["label"]}'),
    (ts_GRAVIS_sumatra, "--", "#2b8cbe", f'GRAVIS {sumatra["label"]}'),
    (ts_GRAVIS_tohoku,  "--", "#f03b20", f'GRAVIS {tohoku["label"]}'),
):
    ax.plot(ts["time"].values, ts.values, ls=line_style, lw=1.2, color=colour, label=label)

# Event markers
ax.axvline(np.datetime64("2004-12-26"), color="tab:red",   ls="--", lw=1, alpha=0.8)
ax.axvline(np.datetime64("2011-03-11"), color="tab:purple", ls="--", lw=1, alpha=0.8)

ax.set_xlabel("Time")
ax.set_ylabel("OBP (cm)")   # ocean bottom pressure (label previously misspelt "OPB")
ax.legend(frameon=True)
ax.grid(True, ls="--", lw=0.4, alpha=0.5)
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m"))

plt.tight_layout()
plt.show()


In [None]:
# Pairwise differences on the common GravIS grid/time axis (used by the EOF and trend cells below).
ctrl_min_v4r5 = GRAVIS_ECCO_ctrl_common - GRAVIS_ECCO_v4r5_common   # ECCO ctrl minus ECCO v4r5
grace_min_v4r5 = GRAVIS_GRACE_common - GRAVIS_ECCO_v4r5_common      # GRACE minus ECCO v4r5
grace_min_ctrl = GRAVIS_GRACE_common - GRAVIS_ECCO_ctrl_common      # GRACE minus ECCO ctrl

In [None]:
%%time

# Leading-EOF analysis of raw GRACE and of the two GRACE-minus-ECCO residuals.
ds = [GRAVIS_GRACE_common, grace_min_v4r5, grace_min_ctrl]

titles = [
    "GRACE Raw (GravIS) EOFs",
    "GRACE - v4r5 (GravIS) EOFs",
    "GRACE - ctrl (GravIS) EOFs",
]

# Per-dataset results, index-aligned with `ds` and `titles`.
solvers = []  # not populated: solver objects are not stored
eofs = []
pcs = []
variance_fractions = []
pcs_norm = []

for i, temp in enumerate(ds):
    # The eofs solver expects time as the leading dimension.
    if temp.dims[0] != 'time':
        temp = temp.transpose('time', 'latitude', 'longitude')

    # Keep only dimension coordinates; missing cells are filled with 0 before solving.
    temp = temp.reset_coords(drop=True)
    solver = Eof(temp.fillna(0))

    eof = solver.eofsAsCorrelation(neofs=5)                # 5 leading patterns (as correlation maps)
    pc = solver.pcs(npcs=5, pcscaling=1)                   # matching PCs
    variance_fraction = solver.varianceFraction(neigs=40)  # explained-variance fraction, 40 modes
    pc_norm = pc / pc.std(dim='time')                      # PCs divided by their own std

    eofs.append(eof)
    pcs.append(pc)
    variance_fractions.append(variance_fraction)
    pcs_norm.append(pc_norm)

    


In [None]:
# Plot each dataset's 5 leading EOF maps with their normalised PCs; one PNG per dataset.
# Iterate over the result lists themselves rather than `range(len(ds))`: `ds` is
# re-bound to an xarray Dataset further down this notebook, after which len(ds)
# no longer counts the EOF datasets.
eof_fig_dir = "./eof_figs"
# Presumably the helper writes with savefig, which does not create missing directories.
os.makedirs(eof_fig_dir, exist_ok=True)

for name, eof_maps, pc_series, var_frac in zip(titles, eofs, pcs_norm, variance_fractions):
    plot_eofs_with_pcs(
        eofs=eof_maps,               # EOF maps with lat/lon coords
        pcs=pc_series,               # normalised PCs, (time x mode)
        title=name,
        n_modes=5,
        variance_fraction=var_frac,  # optional
        save_path=os.path.join(eof_fig_dir, f"{name}.png"),
    )
    




In [None]:
%%time
# Compute trends: per-grid-cell linear trend via compute_trends_manually (defined elsewhere).

# Absolute fields
grace_common_trend = compute_trends_manually(GRAVIS_GRACE_common)
ecco_v4r5_common_trend = compute_trends_manually(GRAVIS_ECCO_v4r5_common)
ecco_ctrl_common_trend = compute_trends_manually(GRAVIS_ECCO_ctrl_common)

# Pairwise differences
grace_min_v4r5_trend = compute_trends_manually(grace_min_v4r5)
grace_min_ctrl_trend = compute_trends_manually(grace_min_ctrl)
ctrl_min_v4r5_trend = compute_trends_manually(ctrl_min_v4r5)



In [None]:
# Rescale trends to per-year (presumably compute_trends_manually returns a per-day rate;
# the trend figure's colour bar is labelled cm/yr).
_DAYS_PER_YEAR = 365  # same factor as before (not 365.25)

grace_common_trend_pyr = grace_common_trend * _DAYS_PER_YEAR
ecco_v4r5_common_trend_pyr = ecco_v4r5_common_trend * _DAYS_PER_YEAR
ecco_ctrl_common_trend_pyr = ecco_ctrl_common_trend * _DAYS_PER_YEAR
grace_min_v4r5_trend_pyr = grace_min_v4r5_trend * _DAYS_PER_YEAR
grace_min_ctrl_trend_pyr = grace_min_ctrl_trend * _DAYS_PER_YEAR
ctrl_min_v4r5_trend_pyr = ctrl_min_v4r5_trend * _DAYS_PER_YEAR

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt

# Panels (a)-(c): absolute OBP trends; (d)-(f): pairwise differences.
data_list = [grace_common_trend_pyr,
             ecco_v4r5_common_trend_pyr,
             ecco_ctrl_common_trend_pyr,
             grace_min_v4r5_trend_pyr,
             grace_min_ctrl_trend_pyr,
             ctrl_min_v4r5_trend_pyr]

title_list = ["GRACE Trends",
              "ECCO v4r5 Trends",
              "ECCO ctrl Trends",
              "GRACE - v4r5 Trends",
              "GRACE - ctrl Trends",
              "ctrl - v4r5 Trends"]

labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

n_rows, n_cols = 2, 3

# A single land feature shared by every panel (no need to rebuild it per axis).
land = cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                    linewidth=0.5,
                                    edgecolor='black',
                                    facecolor='darkgray')

fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 4.5), dpi=500,
                         subplot_kw={'projection': ccrs.PlateCarree(central_longitude=0)})

for i, (ax, data, title) in enumerate(zip(axes.flat, data_list, title_list)):
    pc = ax.pcolormesh(GRAVIS_GRACE_common.longitude,
                       GRAVIS_GRACE_common.latitude,
                       data,
                       transform=ccrs.PlateCarree(),
                       cmap='coolwarm',
                       vmin=-0.3,
                       vmax=0.3)
    ax.set_global()

    # Tick labels only on the outer edges of the panel grid, and no grid lines.
    # Labels are drawn only when requested with draw_labels=True; the per-side
    # switches are top/bottom/left/right_labels (xlabels_bottom / ylabels_left
    # are the older, deprecated names), and xlines/ylines turn the lines off.
    gl = ax.gridlines(draw_labels=True, color='gray', linestyle='dashed', linewidth=0.01)
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = i >= (n_rows - 1) * n_cols  # bottom row only
    gl.left_labels = i % n_cols == 0               # left column only
    gl.xlines = False
    gl.ylines = False
    gl.xlabel_style = {'size': 6}
    gl.ylabel_style = {'size': 6}

    ax.add_feature(land)
    ax.set_title(title, fontsize=12)

    # (a)-(f) panel tag at the top-left corner
    ax.text(-0.01, 1.1, labels[i], transform=ax.transAxes,
            fontsize=8, fontweight='bold', va='top', ha='left',
            bbox=dict(facecolor='white', edgecolor='none', pad=1.5, alpha=0.6))

# Shared horizontal colour bar below the panels (all panels use vmin/vmax = -0.3/0.3).
cbar_ax = fig.add_axes([0.1, 0.08, 0.8, 0.025])
cbar = fig.colorbar(pc, cax=cbar_ax, orientation='horizontal')
cbar.set_label("Ocean Bottom Pressure Trend (cm/yr)")

plt.tight_layout(rect=[0, 0.1, 1, 1])
#plt.savefig("OBP_Trends_Fig_GRACEV2.png", dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Temporal std of the GRACE - v4r5 residual (robust colour limits).
grace_min_v4r5.std("time").plot(robust=True)

In [None]:
# Temporal std of GRACE (GravIS) on a fixed 0-5 colour range.
GRAVIS_GRACE_common.std("time").plot(vmin=0, vmax=5)

### Rough work

In [None]:
# Read in ECCO ctrl data: one llc90 field per month into an all-NaN (time, tile, j, i) buffer.
# NOTE(review): file names are taken from `data_list`, but the assignment of `data_list`
# earlier in this notebook (trend-figure cell) is a list of six trend maps, and its length
# need not equal len(r5_month) — confirm the intended file list before running.
pb_r5_mon = np.full((len(r5_month), 13, 90, 90), np.nan)
for i in tqdm(np.arange(len(data_list))):
    # x100: presumably m -> cm (the Dataset built from this array is labelled "cm");
    # keep index 1 of the leading axis returned by read_llc_to_tiles.
    data = (ecco.read_llc_to_tiles(r5_dir, data_list[i], less_output=True, nk=-1) * 100)[1, :, :, :]
    pb_r5_mon[i, :, :, :] = data

In [None]:
# Treat exact zeros (land) as missing, in place.
np.putmask(pb_r5_mon, pb_r5_mon == 0, np.nan)

In [None]:
# Total ocean area: cell areas summed over the ocean mask.
total_ocn_area = np.nansum(eccor5_grid["rA"] * mask)

In [None]:
# Area-weighted global mean of pb_r5_mon for every month.
# Weights are cell area (rA) over total ocean area; cells that are NaN in pb_r5_mon
# (zeros are set to NaN in an earlier cell — "land is zero") drop out of nansum.
total_ocn_area = np.nansum(eccor5_grid.rA*mask)
weight = eccor5_grid.rA/total_ocn_area
weight_expanded = np.expand_dims(weight, axis=0)
weight = np.tile(weight_expanded, (pb_r5_mon.shape[0], 1, 1, 1))  # one weight map per month

glob_mean = pb_r5_mon*weight
for _summed_axis in ("i", "j", "tile"):  # collapse the trailing axis three times
    glob_mean = np.nansum(glob_mean, axis=-1)
glob_mean_eccor5 = np.copy(glob_mean)    # 1-D, one value per month

# Broadcast the per-month mean back to the (time, 13, 90, 90) field shape.
glob_mean = np.tile(glob_mean[:, np.newaxis, np.newaxis, np.newaxis], (1, 13, 90, 90))

In [None]:
# Expect (n_months, 13, 90, 90)
np.shape(glob_mean)

In [None]:
# Wrap pb_r5_mon (time, tile, j, i) in an xarray Dataset, using the llc90 grid indices as coordinates.
i = eccor5_grid['i']        # (90,)
j = eccor5_grid['j']        # (90,)
tile = eccor5_grid['tile']  # (13,)
time = (["time"], r5_month)

ds = xr.Dataset(
    {"pb": (["time", "tile", "j", "i"], pb_r5_mon, {"units": "cm"})},
    coords={"time": time, "tile": tile, "j": j, "i": i},
    attrs={"description": "ECCO pb, cm", "units": 'cm'},
)

# Optional: write to disk
#ds.to_netcdf(os.path.join(output_dir, 'pb_ECCO_ctrl.nc'))

In [None]:
# Relabel the ctrl time axis with month-start dates, Jan 1992 - Dec 2019.
new_time = pd.date_range(start="1992-01-01", end="2019-12-31", freq="MS")  # MS = Month Start

# The length check must be made against `ds`, the object being relabelled
# (regridded_data_da is the regridded GRACE array and has a different time axis).
assert len(new_time) == ds.sizes['time'], "Mismatch in time dimension size!"

ds = ds.assign_coords(time=new_time)

In [None]:
# Area-weighted global-mean ctrl bottom pressure per month.
# A plain mean over (tile, j, i) gives every llc90 cell equal weight regardless of its
# area; weighting by rA matches glob_mean_eccor5 computed above. xarray's weighted
# mean excludes NaN (land) cells from both the weighted sum and the sum of weights.
cell_area = eccor5_grid.rA.fillna(0)  # .weighted() does not accept NaN weights
ecco_ctrl_ts = ds.pb.weighted(cell_area).mean(dim=['tile', 'j', 'i'], skipna=True)
ecco_ctrl_ts.plot()

In [None]:
# Restrict ctrl to the GRACE era, then remove its time mean.
ecco_ctrl_ts = ecco_ctrl_ts.sel(time=slice("2002-04-17", "2019-12-01"))
temp = ecco_ctrl_ts.mean("time")
ecco_ctrl_ts_anom = ecco_ctrl_ts - temp

In [None]:
# ECCO v4r5 global-mean series: x100 (presumably m -> cm), monthly time axis,
# GRACE-era window, then anomaly about the window mean.
v4r5_raw = (
    (xr.open_dataset("/glade/work/mengnanz/V4r5/GMSL.nc") * 100)
    .assign_coords(time=new_time)
    .sel(time=slice("2002-04-17", "2019-12-01"))
)
temp = v4r5_raw.mean("time")
v4r5_raw_anom = v4r5_raw - temp

In [None]:
# Intentionally a no-op. v4r5_raw is already multiplied by 100 in the cell that
# loads it; multiplying again here scaled it by 10^4 overall on a top-to-bottom run.

In [None]:
# Intentionally a no-op. The time relabelling, GRACE-era slice and anomaly for
# v4r5_raw are done in the cell that loads it. Repeating them here fails on a
# top-to-bottom run: v4r5_raw is already cut to 2002-04..2019-12, so it is shorter
# than new_time (1992-01..2019-12) and assign_coords(time=new_time) raises.

In [None]:
# GRACE global ocean-mass time series (relative path: expected in the working directory).
grace_ts = xr.open_dataset(filename_or_obj="ocean_mass_GRACE_200204_202502.nc")

In [None]:
# Display the GRACE dataset as loaded (its time axis is converted from decimal years further below).
grace_ts

In [None]:
import numpy as np
import pandas as pd

# Re-open the raw GRACE series so the conversions below always start from the file contents.
grace_ts = xr.open_dataset("ocean_mass_GRACE_200204_202502.nc")

def decimal_years_to_actual_date_midday(decimal_years):
    """Convert decimal years to 12:00 timestamps on the nearest calendar day.

    The fractional part is scaled by the true length of that year (365 or 366
    days), rounded to a whole-day offset from 1 January, and 12 hours are added.

    Parameters
    ----------
    decimal_years : numpy.ndarray of float
        Decimal-year values such as 2002.29.

    Returns
    -------
    pandas.DatetimeIndex cast to ``datetime64[s]`` (no sub-second part).
    """
    whole_years = np.floor(decimal_years).astype(int)
    fraction = decimal_years - whole_years

    jan_first = pd.to_datetime(whole_years, format='%Y')
    next_jan_first = pd.to_datetime(whole_years + 1, format='%Y')
    year_length_days = (next_jan_first - jan_first).days  # 365 or 366

    offset_days = np.round(fraction * year_length_days).astype(int)

    noon = jan_first + pd.to_timedelta(offset_days, unit="D") + pd.Timedelta(hours=12)
    return noon.astype("datetime64[s]")

# Put GRACE on calendar time, divide by 10 (presumably mm -> cm), keep the
# 2002-04..2019-12 window, and subtract the mean over all dims.
grace_ts = (
    grace_ts
    .assign_coords(time=decimal_years_to_actual_date_midday(grace_ts.time.values))
    .pipe(lambda dset: dset / 10)
    .sel(time=slice("2002-04-17", "2019-12-01"))
)
grace_ts_anom = grace_ts - (avg := grace_ts.mean())

In [None]:
# Intentionally a no-op. grace_ts is already divided by 10 in the conversion cell
# above; dividing again here made it 100x smaller overall, and the anomaly cell
# below (grace_ts - grace_ts.mean()) then carried that error into grace_ts_anom.

In [None]:
# GRACE-era window (re-applying it to an already-windowed series changes nothing).
grace_ts = grace_ts.loc[{"time": slice("2002-04-17", "2019-12-01")}]

In [None]:
# Anomaly about the overall mean of each variable.
grace_ts_anom = grace_ts - (avg := grace_ts.mean())

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

# GRACE-era global-mean anomalies: the two ECCO runs as lines, GRACE as dots.
anomaly_series = [
    (v4r5_raw_anom.time, v4r5_raw_anom.global_mean_barystatic_sea_level_anomaly,
     dict(label='ECCO v4r5', linestyle='-', linewidth=2, color="#ff7f00")),
    (ecco_ctrl_ts_anom.time, ecco_ctrl_ts_anom,
     dict(label='ECCO Control', linestyle='-', linewidth=2, color="#33a02c")),
    (grace_ts_anom.time, grace_ts_anom.ocean_mass,
     dict(label='GRACE', linestyle='', marker='.', markersize=5, color="#1f78b4")),
]

fig, ax4 = plt.subplots(figsize=(11, 5), dpi=120)
for t_axis, values, plot_kw in anomaly_series:
    ax4.plot(t_axis, values, **plot_kw)

ax4.set_ylabel("Ocean Bottom Pressure Spatial Avg. (cm)", fontweight='bold')
ax4.set_xlim(pd.Timestamp("2002-04-01"), pd.Timestamp("2019-12-01"))

# Year ticks every two years
ax4.xaxis.set_major_locator(mdates.YearLocator(2))
ax4.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax4.grid(True, ls='--', lw=0.4, alpha=0.6)
ax4.legend(frameon=True, ncol=3)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

# Same comparison as above, adding the GRACE V2 global mean (glob_mean_GRACE) in red.
anomaly_series = [
    (v4r5_raw_anom.time, v4r5_raw_anom.global_mean_barystatic_sea_level_anomaly,
     dict(label='ECCO v4r5', linestyle='-', linewidth=2, color="#ff7f00")),
    (ecco_ctrl_ts_anom.time, ecco_ctrl_ts_anom,
     dict(label='ECCO Control', linestyle='-', linewidth=2, color="#33a02c")),
    (grace_ts_anom.time, grace_ts_anom.ocean_mass,
     dict(label='GRACE', linestyle='', marker='.', markersize=5, color="#1f78b4")),
    (glob_mean_GRACE.time, glob_mean_GRACE,
     dict(label='GRACE V2', linestyle='', marker='.', markersize=5, color="red")),
]

fig, ax4 = plt.subplots(figsize=(11, 5), dpi=120)
for t_axis, values, plot_kw in anomaly_series:
    ax4.plot(t_axis, values, **plot_kw)

ax4.set_ylabel("Ocean Bottom Pressure Spatial Avg. (cm)", fontweight='bold')
ax4.set_xlim(pd.Timestamp("2002-04-01"), pd.Timestamp("2019-12-01"))

# Year ticks every two years
ax4.xaxis.set_major_locator(mdates.YearLocator(2))
ax4.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax4.grid(True, ls='--', lw=0.4, alpha=0.6)
ax4.legend(frameon=True, ncol=3)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

# Compact version of the global-mean anomaly comparison.
ecco_line_kw = dict(lw=1.6, ls="-")

fig, ax = plt.subplots(figsize=(5.4, 2.8), dpi=500)

ax.plot(v4r5_raw_anom.time, v4r5_raw_anom.global_mean_barystatic_sea_level_anomaly,
        label="ECCO v4r5", color="#ff7f00", **ecco_line_kw)
ax.plot(ecco_ctrl_ts_anom.time, ecco_ctrl_ts_anom,
        label="ECCO Control", color="#33a02c", **ecco_line_kw)
ax.plot(grace_ts_anom.time, grace_ts_anom.ocean_mass,
        label="GRACE", color="#1f78b4", lw=0, marker="o", ms=2.8, mew=0)

# Axis limits and labels
ax.set_ylabel("Ocean Bottom Pressure Spatial Avg. (cm)", fontsize=9)
ax.set_xlim(pd.Timestamp("2002-04-01"), pd.Timestamp("2019-12-01"))
ax.set_ylim(-12, 16)  # fixed y-range; adjust if needed

# Two-yearly ticks, small tick labels, light grid
ax.xaxis.set_major_locator(mdates.YearLocator(2))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.tick_params(axis="both", labelsize=8)
ax.grid(True, ls="--", lw=0.4, alpha=0.6)

# Compact legend in the top-right corner
leg = ax.legend(loc="upper right", frameon=True, fontsize=8,
                handlelength=2.2, borderaxespad=0.6)
leg.get_frame().set_alpha(0.9)

plt.tight_layout()
plt.show()


### GRACE V2 vs. V3 differences

In [None]:
# GRACE (V2) bottom pressure on the llc90 grid, with decimal-year time stamps.
aligned_grace = xr.open_dataset(filename_or_obj=os.path.join(input_dir, 'aligned_grace.nc'))


In [None]:
# Convert GRACE decimal-year times to the first day of the calendar month each falls in.
#
# The decimal year is turned into an actual instant (fraction x length of that
# year in days, 365 or 366) and then floored to its month. Rounding fraction*12
# instead would push dates in the second half of a month into the *next* month
# and could produce a month of 13 for late-December samples.
time_vals = np.asarray(aligned_grace.time, dtype=float)

base_year = np.floor(time_vals).astype(int)   # integer year
fractional_part = time_vals - base_year        # fraction of that year elapsed

year_start = pd.to_datetime(base_year.astype(str), format="%Y")
days_in_year = np.where(year_start.is_leap_year, 366, 365)
instants = year_start + pd.to_timedelta(fractional_part * days_in_year, unit="D")

# Month-start DatetimeIndex, e.g. 2002-04-17 -> 2002-04-01
dates = instants.to_period("M").to_timestamp()

aligned_grace = aligned_grace.assign_coords(time=dates)

In [None]:
# Load the llc90 model grid from ECCO-GRID.nc; its rA, hFacC, XC and YC fields are
# used by the global-mean and regridding cells in this notebook.
eccor5_grid = ecco.load_ecco_grid_nc(input_dir_ecco_grid, 'ECCO-GRID.nc')

In [None]:
# GRACE put into common grid: regrid aligned_grace.pb from llc90 to a regular
# 0.5-degree lat-lon grid, one time step at a time, after masking dry top-level cells.

import xarray as xr
import numpy as np
import ecco_v4_py as ecco  # Ensure the ECCO tools are installed
import pandas as pd

data = aligned_grace.pb
grid = eccor5_grid

# Target spacing (degrees) and global extent
new_grid_delta_lat = new_grid_delta_lon = 0.5
new_grid_min_lat, new_grid_max_lat = -90, 90
new_grid_min_lon, new_grid_max_lon = -180, 180

# Wet-cell mask of the top level; identical for every time step, so build it once.
wet_top = grid.hFacC.isel(k=0) != 0

global_data_list = []
for t in range(data.sizes['time']):
    tmp = data.isel(time=t).where(wet_top)
    _, _, _, _, regridded_data = ecco.resample_to_latlon(
        grid.XC, grid.YC, tmp,
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        mapping_method='nearest_neighbor',
        fill_value=np.nan,
    )
    global_data_list.append(regridded_data)

global_data_array = np.stack(global_data_list, axis=0)  # (time, lat, lon)

# NOTE(review): linspace includes both endpoints, so these labels run from -90 to 90
# (spacing 180/(n-1), not 0.5) rather than sitting on cell centres such as -89.75.
# The first return values of resample_to_latlon (discarded above) are presumably the
# cell-centre grids — TODO confirm, and change every grid built this way together so
# the datasets stay aligned.
lat = np.linspace(new_grid_min_lat, new_grid_max_lat, global_data_array.shape[1])
lon = np.linspace(new_grid_min_lon, new_grid_max_lon, global_data_array.shape[2])

regridded_data_da = xr.DataArray(
    global_data_array,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': data.time, 'latitude': lat, 'longitude': lon},
    name='pb',
)

# Regridded GRACE, keeping the month-start time axis of aligned_grace
aligned_GRACE_pp_rg = regridded_data_da


In [None]:
# Restrict GRACE and both ECCO runs to the months that all three records share.
grace_time = aligned_GRACE_pp_rg.time

# Trim ECCO to the GRACE span first...
ECCO_pb_v4r5_common = ECCO_pb_v4r5_pp_rg.sel(time=slice(grace_time.min(), grace_time.max()))
ECCO_pb_ctrl_common = ECCO_pb_ctrl_pp_rg.sel(time=slice(grace_time.min(), grace_time.max()))

# ...then keep only months present in all three. Including ctrl in the intersection
# prevents a KeyError in ECCO_pb_ctrl_common.sel below when ctrl lacks a month that
# v4r5 and GRACE both have.
# NOTE(review): if two GRACE samples map to the same month, .sel returns both rows for
# that month — check aligned_GRACE_pp_rg.indexes["time"].is_unique.
common_time = np.intersect1d(
    np.intersect1d(ECCO_pb_v4r5_common.time.values, ECCO_pb_ctrl_common.time.values),
    aligned_GRACE_pp_rg.time.values,
)

ECCO_pb_v4r5_common = ECCO_pb_v4r5_common.sel(time=common_time)
ECCO_pb_ctrl_common = ECCO_pb_ctrl_common.sel(time=common_time)
GRACE_common = aligned_GRACE_pp_rg.sel(time=common_time)

In [None]:
# Per-cell temporal std of GRACE V2 over the common months, computed eagerly.
grace_stdV2 = GRACE_common.std("time", skipna=True).load()

In [None]:
# Std difference: GRACE V2 (grace_stdV2) minus grace_std, which is defined elsewhere
# (presumably the V3 product, per the "Std. V2 - Std. V3" plot title) — xarray aligns
# the two on their lat/lon labels before subtracting.
dif = grace_stdV2-grace_std

In [None]:
# Map of the V2 - V3 std difference. DataArray.plot returns the plotted artist
# (e.g. a QuadMesh for 2-D data), not an Axes, so calling set_title on its return
# value fails; create the Axes explicitly and draw into it.
fig, ax = plt.subplots()
dif.plot(ax=ax, robust=True)
ax.set_title("Std. V2 - Std. V3", fontsize=12, fontweight="bold", loc="left", pad=6)