In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import os, re, glob, sys
from llc_uv_shift import llc_uv_shift
import temp_anom_budgets as tab
from dask import delayed, compute
from cal_wmb_llc import vertical_pairwise_avg
import time
from dask.distributed import Client, LocalCluster
import xgcm

# Wall-clock reference for the timing report printed at the end of this cell.
start_time = time.time()

# === Basic configuration ===
# Root directory of the ECCOv4r4 files; the `prefix` entries and the
# climatology file names are joined onto it.
data_dir = '/dfs9/hfdrake_hpc/datasets/ECCOv4r4/'
# Grid-geometry and region-mask files; they are opened with these paths as-is,
# i.e. relative to the current working directory.
grid_path = 'GRID_GEOMETRY_ECCO_V4r4_native_llc0090.nc'
mask_path = 'goa_mask_llc90.nc'
# rho_pot scales volumes and fluxes to mass further down; rhoconst divides the
# freshwater flux.  Presumably both are densities in kg/m^3 -- TODO confirm.
rho_pot = 1029.0
rhoconst = 1029.0
# Connectivity of the 13 faces (0-12), passed unchanged to xgcm.Grid via its
# `face_connections` argument and to the project budget helpers.
# Layout: {face_dim: {face: {axis: (neighbour_a, neighbour_b)}}}; each
# neighbour is a (face, axis, flag) tuple, or None where no face is attached.
# NOTE(review): which neighbour is "left"/"right" and the meaning of the
# boolean flag follow xgcm's convention -- confirm against the xgcm docs.
face_connections = {'face':
                    {0: {'X':  ((12, 'Y', False), (3, 'X', False)),
                         'Y':  (None,             (1, 'Y', False))},
                     1: {'X':  ((11, 'Y', False), (4, 'X', False)),
                         'Y':  ((0, 'Y', False),  (2, 'Y', False))},
                     2: {'X':  ((10, 'Y', False), (5, 'X', False)),
                         'Y':  ((1, 'Y', False),  (6, 'X', False))},
                     3: {'X':  ((0, 'X', False),  (9, 'Y', False)),
                         'Y':  (None,             (4, 'Y', False))},
                     4: {'X':  ((1, 'X', False),  (8, 'Y', False)),
                         'Y':  ((3, 'Y', False),  (5, 'Y', False))},
                     5: {'X':  ((2, 'X', False),  (7, 'Y', False)),
                         'Y':  ((4, 'Y', False),  (6, 'Y', False))},
                     6: {'X':  ((2, 'Y', False),  (7, 'X', False)),
                         'Y':  ((5, 'Y', False),  (10, 'X', False))},
                     7: {'X':  ((6, 'X', False),  (8, 'X', False)),
                         'Y':  ((5, 'X', False),  (10, 'Y', False))},
                     8: {'X':  ((7, 'X', False),  (9, 'X', False)),
                         'Y':  ((4, 'X', False),  (11, 'Y', False))},
                     9: {'X':  ((8, 'X', False),  None),
                         'Y':  ((3, 'X', False),  (12, 'Y', False))},
                     10: {'X': ((6, 'Y', False),  (11, 'X', False)),
                          'Y': ((7, 'Y', False),  (2, 'X', False))},
                     11: {'X': ((10, 'X', False), (12, 'X', False)),
                          'Y': ((8, 'Y', False),  (1, 'X', False))},
                     12: {'X': ((11, 'X', False), None),
                          'Y': ((9, 'Y', False),  (0, 'X', False))}}}

# Path prefixes, relative to data_dir, used to glob the input files of each
# variable group.  The keys are the short group names used for `day_files`.
# Going by the path strings, 'ts' and 'ssh' are snapshot products ("snap") and
# the remaining groups are daily means ("day_mean").
prefix = {
    'ts':      'OCEAN_TEMPERATURE_SALINITY_snap/OCEAN_TEMPERATURE_SALINITY_snap',
    'tsdaily': 'ECCO_L4_TEMP_SALINITY_LLC0090GRID_DAILY_V4R4/OCEAN_TEMPERATURE_SALINITY_day_mean_',
    'tadv':    'ECCO_L4_OCEAN_3D_TEMPERATURE_FLUX_LLC0090GRID_DAILY_V4R4/OCEAN_3D_TEMPERATURE_FLUX_day_mean_',
    'hflux':   'ECCO_L4_HEAT_FLUX_LLC0090GRID_DAILY_V4R4/OCEAN_AND_ICE_SURFACE_HEAT_FLUX_day_mean_',
    'ssh':     'SEA_SURFACE_HEIGHT_snap/SEA_SURFACE_HEIGHT_snap_',
    'sshdaily': 'ECCO_L4_SSH_LLC0090GRID_DAILY_V4R4/SEA_SURFACE_HEIGHT_day_mean_',
    'sflux':   'ECCO_L4_FRESH_FLUX_LLC0090GRID_DAILY_V4R4/OCEAN_AND_ICE_SURFACE_FW_FLUX_day_mean_',
    'volflux': 'ECCO_L4_OCEAN_3D_VOLUME_FLUX_LLC0090GRID_DAILY_V4R4/OCEAN_3D_VOLUME_FLUX_day_mean_',
    'vstar':   'ECCO_L4_BOLUS_LLC0090GRID_DAILY_V4R4/OCEAN_BOLUS_VELOCITY_day_mean_'
}

# Climatology files live directly under the data root.
clim_dir = data_dir
# Alternative climatology set, kept for reference.  Judging by the file names
# these are the 30-day moving-average ("30dMA") versions -- TODO confirm.
#clim_files = {
#    "tadv":            "tadv_clim_1992-2017_30dMA_allDays.nc",
#    "hflux":          "hflux_clim_1992-2017_30dMA_allDays.nc",
#     "sflux":         "sflux_clim_1992-2017_30dMA_allDays.nc",
#     "tsdaily":     "tsdaily_clim_1992-2017_30dMA_allDays.nc",
#     "vstar":         "vstar_clim_1992-2017_30dMA_allDays.nc",
#     "vol":             "vol_clim_1992-2017_30dMA_allDays.nc",
#     "tsdailyS":   "tsdailyS_clim_1992-2017_30dMA_allDays.nc",
#     "eddyforcing": "dailyclim_eddyforcing.nc"
# }
# Active climatology set: one file name (relative to clim_dir) per key.
clim_files = {
    "tadv":            "tadv_clim_1992-2017_raw_allDays.nc",
    "hflux":          "hflux_clim_1992-2017_raw_allDays.nc",
     "sflux":         "sflux_clim_1992-2017_raw_allDays.nc",
     "tsdaily":     "tsdaily_clim_1992-2017_raw_allDays.nc",
     "vstar":         "vstar_clim_1992-2017_raw_allDays.nc",
     "vol":             "vol_clim_1992-2017_raw_allDays.nc",
     "tsdailyS":   "tsdailyS_clim_1992-2017_raw_allDays.nc",
     "eddyforcing": "dailyclim_eddyforcing_raw.nc"
 }

# Variables to keep when a climatology file is opened; keys that are missing
# here keep every variable of their file.
clim_keep_vars = {
     "tadv":        ["DFxE_TH", "DFyE_TH", "DFrE_TH", "DFrI_TH"],
     "tsdailyS":    ["THETA"],
     "eddyforcing": ["Hnabla_eddy", "Vnabla_eddy"],
     "tsdaily":     ["THETA"]
 }

def sel_and_retime(path: str, mmdd_seq: np.ndarray, time_coord: xr.DataArray, keep=None) -> xr.Dataset:
    """Open a climatology file and re-index it from ``mmdd`` onto a time axis.

    Args:
        path: Climatology file to open.
        mmdd_seq: ``mmdd`` labels to select, in the order wanted on output.
        time_coord: Coordinate assigned to the resulting ``time`` dimension.
        keep: Optional list of variable names to retain; falsy keeps all.

    Returns:
        The selected dataset with ``mmdd`` renamed to ``time``, ``time_coord``
        attached and the remaining non-index coordinates dropped.
    """
    # One chunk per mmdd entry.
    dataset = xr.open_dataset(path).chunk(chunks={"mmdd": 1})
    if keep:
        dataset = dataset[keep]

    picked = dataset.sel(mmdd=mmdd_seq)
    # Drop the mmdd labels first, then turn the bare dimension into `time`.
    picked = picked.drop_vars('mmdd')
    picked = picked.rename({'mmdd': 'time'})

    stamped = picked.assign_coords(time=time_coord)
    return stamped.reset_coords(drop=True)
def sel_and_retime_clim(key: str, mmdd_seq, time_coord):
    """Run ``sel_and_retime`` on the climatology file registered under ``key``.

    The file name comes from ``clim_files`` (joined onto ``clim_dir``) and the
    optional variable subset from ``clim_keep_vars``.
    """
    return sel_and_retime(
        os.path.join(clim_dir, clim_files[key]),
        mmdd_seq,
        time_coord,
        keep=clim_keep_vars.get(key),
    )
def align_time_like(da_target: xr.Dataset, da_ref: xr.Dataset) -> xr.Dataset:
    """Give ``da_target`` the timestamps ``da_ref`` uses on the same days.

    Each target timestamp is replaced by the reference timestamp that falls on
    the same calendar day; target timestamps without a same-day reference are
    left unchanged.  A copy is returned, the input is not modified.
    """
    ref_time = da_ref.time
    # If several reference stamps share a day, the last one wins.
    stamp_by_day = {
        day: stamp
        for day, stamp in zip(ref_time.dt.floor('D').values, ref_time.values)
    }

    tgt_time = da_target.time
    retimed = [
        stamp_by_day.get(day, original)
        for day, original in zip(tgt_time.dt.floor('D').values, tgt_time.values)
    ]

    result = da_target.copy()
    result['time'] = np.array(retimed)
    return result
# Locate the file for the first day of the following month
def find_next_day_file(prefix_path, next_day, base_dir=None):
    """Return the path of the file stamped with ``next_day`` at 00:00:00.

    Args:
        prefix_path: Path prefix, relative to the search root, that precedes
            the date part of the file name.
        next_day: Date-like object providing ``strftime`` (for example a
            ``pandas.Timestamp``).
        base_dir: Directory to search.  Defaults to the module-level
            ``data_dir`` when omitted.

    Returns:
        The alphabetically first matching path, or ``None`` when no file
        matches.
    """
    root = data_dir if base_dir is None else base_dir
    date_str = next_day.strftime('%Y-%m-%d')
    pattern = os.path.join(root, f"{prefix_path}*{date_str}T000000*.nc")
    # glob returns matches in filesystem order; sort so that the chosen file
    # is deterministic when more than one file matches.
    files = sorted(glob.glob(pattern))
    return files[0] if files else None
    
def get_month_str(filename):
    """Extract the ``YYYYMM`` stamp from a file name containing a date.

    Two date layouts are recognised, both delimited by underscores:
    ``_YYYY-MM-DDThhmmss_`` is tried first, then ``_YYYY-MM-DD_``.

    Raises:
        ValueError: If neither layout occurs in ``filename``.
    """
    date_regexes = (
        r'_(\d{4})-(\d{2})-\d{2}T\d{6}_',  # layout with a "T" time part
        r'_(\d{4})-(\d{2})-\d{2}_',        # layout without a time part
    )
    attempts = (re.search(regex, filename) for regex in date_regexes)
    found = next((hit for hit in attempts if hit), None)
    if found is None:
        raise ValueError(f"No valid date found in filename: {filename}")
    year, mon = found.groups()
    return year + mon


# === Time settings ===
# Month to process, as YYYYMM.  Hard-coded for interactive runs; the
# commented lines below are the command-line variant.
month = '199202'
# Read the year-month from the command line, formatted like 199201
#if len(sys.argv) < 2:
#    raise ValueError("Please provide a YYYYMM argument")
#month = sys.argv[1]
#outputdir = f'G_global_wmb_anomaly_{month}.nc'
outputdir = f'G_goa_wmb_anomaly_{month}.nc'

# An existing output file is deleted so that the month is recomputed.  The
# message states exactly that (it used to claim the computation was skipped).
# To skip instead, e.g. in batch runs, replace the removal with sys.exit(0).
if os.path.exists(outputdir):
    print(f"⚠️ 输出文件已存在：{outputdir}，删除后重新计算")
    os.remove(outputdir)

this_month = pd.to_datetime(month, format='%Y%m')
next_month_first_day = this_month + pd.DateOffset(months=1)

# Collect the input files of every variable group for the selected month.
day_files = {}
has_next_month_data = {}

for key, pref in prefix.items():
    # Keep only the files whose embedded date falls in the selected month.
    all_files = sorted(glob.glob(os.path.join(data_dir, pref + '*.nc')))
    current_month_files = [path for path in all_files if get_month_str(path) == month]

    # Only 'ts' and 'ssh' get the first day of the next month appended.
    if key in ('ts', 'ssh'):
        next_day_file = find_next_day_file(pref, next_month_first_day)
        has_next_month_data[key] = bool(next_day_file)
        if not next_day_file:
            print(f"{month} ➤ {key}: ❌ 没有下月1号的数据")
        else:
            current_month_files.append(next_day_file)

    day_files[key] = current_month_files


# === Grid and mask ===
# The grid file's `tile` dimension is renamed to `face`; grid and mask are
# both loaded into memory.
ECCOgrid = xr.open_dataset(grid_path).rename({'tile': 'face'}).load()
mask = xr.open_dataarray(mask_path).load()
rA, drF, hFacC, Depth = ECCOgrid['rA'], ECCOgrid['drF'], ECCOgrid['hFacC'], ECCOgrid['Depth']
# Product rA * drF * hFacC, ordered (face, k, j, i) in float64.
# Presumably the wet-cell volume in m^3 -- TODO confirm units.
vol = (rA * drF * hFacC).transpose('face', 'k', 'j', 'i').astype('float64')

# === Lazy data loading ===
print(f"Processing {month} ...")


# Chunk sizes handed to xr.open_mfdataset.  Both the `tile` and the `face`
# name are listed because the dimension is renamed from `tile` to `face` while
# the files are opened.  Each key now appears exactly once: the literals used
# to repeat 'face', which Python silently collapses.
open3d = dict(chunks={'time': 1, 'k': 50, 'k_l': 50, 'face': 13, 'j': 90, 'i': 90, 'i_g': 90, 'j_g': 90, 'tile': 13})
open2d = dict(chunks={'time': 1, 'face': 13, 'j': 90, 'i': 90, 'i_g': 90, 'j_g': 90, 'tile': 13})

# Number of daily files sliced from each group by the loader below; the two
# snapshot groups take one file more.
testdays = 2


def preprocess(ds):
    """Rename the ``tile`` dimension of a freshly opened file to ``face``."""
    tile_to_face = {'tile': 'face'}
    return ds.rename(tile_to_face)

# Open every group lazily.  Daily groups use the first `testdays` files; the
# snapshot groups (ts, ssh) use one file more.
tscache  = xr.open_mfdataset(day_files['ts'][0:testdays+1], preprocess=preprocess, chunks=open3d["chunks"])
# Snapshot timestamps are moved to 12:00 of their calendar day.
# NOTE(review): presumably to line up with the daily-mean time stamps -- confirm.
tscache  = tscache.assign_coords( time=tscache.time.dt.floor("D") + pd.Timedelta(hours=12) )
print("load ts done!")
tsdaily  = xr.open_mfdataset(day_files['tsdaily'][0:testdays], preprocess=preprocess, chunks=open3d["chunks"])
#tsdaily = tsdaily.reset_coords(drop=True)
print("load tsdaily done!")
tbudget  = xr.open_mfdataset(day_files['tadv'][0:testdays], preprocess=preprocess, chunks=open3d["chunks"])
print("load tadv done!")
hflux    = xr.open_mfdataset(day_files['hflux'][0:testdays], preprocess=preprocess, chunks=open2d["chunks"])
print("load hflux done!")
ssh      = xr.open_mfdataset(day_files['ssh'][0:testdays+1], preprocess=preprocess, chunks=open2d["chunks"])
# Same 12:00 re-stamping as for tscache.
ssh      = ssh.assign_coords( time=ssh.time.dt.floor("D") + pd.Timedelta(hours=12) )
print("load ssh done!")
sshdaily = xr.open_mfdataset(day_files['sshdaily'][0:testdays], preprocess=preprocess, chunks=open2d["chunks"])
print("load sshdaily done!")
sflux    = xr.open_mfdataset(day_files['sflux'][0:testdays], preprocess=preprocess, chunks=open2d["chunks"])
print("load sflux done!")
volflux  = xr.open_mfdataset(day_files['volflux'][0:testdays], preprocess=preprocess, chunks=open3d["chunks"])
print("load volflux done!")
vstar    = xr.open_mfdataset(day_files['vstar'][0:testdays], preprocess=preprocess, chunks=open3d["chunks"])
print("load vstar done!")

# Two xgcm grids with the same face connectivity: one built from the snapshot
# dataset and one from the daily-mean dataset.
grid_snap = xgcm.Grid(tscache, periodic=False, face_connections=face_connections )
grid_daily = xgcm.Grid(tsdaily, periodic=False, face_connections=face_connections )

#

# === Time alignment ===
# Disabled: the snapshot timestamps are re-stamped to 12:00 when loaded.
#ssh = align_time_like(ssh, tbudget)
#tscache = align_time_like(tscache, tbudget)
# === Climatology data ===
# Select the climatology entries matching the month-day of each daily time
# step and put them on the tbudget time axis, for every key in clim_files.
mmdd_seq = tbudget.time.dt.strftime('%m%d').values
clim = {k: sel_and_retime_clim(k, mmdd_seq, tbudget.time) for k in clim_files}
# Same selection on the snapshot time axis, which is one step longer.
mmdd_seq = tscache.time.dt.strftime('%m%d').values
clim_tsdaily = sel_and_retime_clim('tsdaily', mmdd_seq, tscache['THETA'].time)

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")

print("loading complete")
print("preparing Tprime budget ...")



## Dask client on a SLURM cluster

In [None]:
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask, os, sys

# Dashboard port for the scheduler started from this notebook.
my_dashboard_port = 8809#8501 # pick a unique one!
#

# NOTE(review): dask reads DASK_* environment variables when its configuration
# is first loaded, so setting this after `import dask` may have no effect in
# this process; the same address is passed explicitly via scheduler_options.
os.environ["DASK_DISTRIBUTED__DASHBOARD__ADDRESS"] = f"0.0.0.0:{my_dashboard_port}"
# extra resilience: retry tasks when workers die
# NOTE(review): confirm that "distributed.scheduler.default-task-retries" is a
# key the installed distributed version recognises.
# NOTE(review): distributed's documented defaults for memory.target / spill are
# lower (0.60 / 0.70), so the values below delay spilling rather than start it
# earlier -- confirm this is intended.
dask.config.set({
    "distributed.scheduler.default-task-retries": 3,   # try failed tasks 3x
    "distributed.worker.memory.target": 0.85,          # fraction of memory at which spilling starts
    "distributed.worker.memory.spill": 0.90,
    "distributed.worker.memory.terminate": 0.98,
})

# tiny scheduler + notebook are already on your standard node
# now create *workers* on the free partition
cluster = SLURMCluster(
    queue="free",                 # <- partition for workers
    walltime="03:00:00",          # preemptible jobs usually shorter; autoscale will re-submit
    cores=2,                      # per worker *job* (processes*threads below should ~= cores)
    processes=1,                  # n workers per job
    memory="32GB",                # total per job
    local_directory="/tmp",       # fast node-local spill
    job_extra_directives=[
        "--qos=free-part",
        "--requeue",
        "--signal=TERM@120",
    ],
    # Workers run the same interpreter as this notebook.
    python=sys.executable,
    job_script_prologue=[
        "set -euo pipefail",
        "export HDF5_USE_FILE_LOCKING=FALSE",
        "export DASK_TEMPORARY_DIRECTORY=/tmp",
    ],
    scheduler_options={"dashboard_address": f"0.0.0.0:{my_dashboard_port}"},
)

client = Client(cluster)

# Elastic scale: grow to use free cycles; shrink when idle
cluster.adapt(minimum_jobs=1, maximum_jobs=24, wait_count=3, interval="10s")

print("Scheduler:", client)
print("Dashboard:", client.dashboard_link)

## calculation of advection

In [None]:
start_time = time.time()

# 7) Compute the individual terms.  The original intent was to keep everything
#    lazy until the file is written; g_adv is nevertheless loaded at the end
#    of this cell.
# 7.1 Cell mass including the ETAN stretching factor
s_star = 1.0 + ssh.ETAN / Depth  # dimensionless
Mass = (s_star * (vol * rho_pot)).transpose('time', 'k', 'face', 'j', 'i')  # kg


# 7.2 Surface volume flux (m^3/s): anomalous freshwater flux times cell area,
#     divided by rhoconst
freshwater = (sflux.oceFWflx - clim['sflux'].oceFWflx) * rA / rhoconst  # m^3/s

# 7.3 Mass advection & GM volume flux (m^3/s)
# The template tells map_blocks the names, dims and chunking of the result;
# it reuses the first chunk size of each THETA dimension.
chunk_dict = dict(zip(tsdaily.THETA.dims, [c[0] for c in tsdaily.THETA.chunks]))
template = xr.Dataset({
    'utrans': xr.zeros_like(tsdaily.THETA ).chunk(chunk_dict),
    'utrans_right': xr.zeros_like(tsdaily.THETA ).chunk(chunk_dict),
    'vtrans': xr.zeros_like(tsdaily.THETA ).chunk(chunk_dict),
    'vtrans_up': xr.zeros_like(tsdaily.THETA ).chunk(chunk_dict), })

# Parallel computation, block by block
# Only the grid variables listed here are shipped to the workers.
ECCOgriduse = xr.Dataset({
    'drF': ECCOgrid['drF'].astype('float64'),
    'dyG': ECCOgrid['dyG'].astype('float64'),
    'dxG': ECCOgrid['dxG'].astype('float64'), })
# NOTE(review): the helper is called with var_prefix='MASS', yet none of the
# four template variable names carries that prefix.  map_blocks requires the
# helper's output variables to match the template -- confirm what
# save_G_adv_surfacevolume_ds actually returns.
g_adv = xr.map_blocks(
    tab.save_G_adv_surfacevolume_ds,
    volflux,
    kwargs={'ECCOgrid': ECCOgriduse, 'face_connections': face_connections, 'var_prefix': 'MASS'},
    template=template.drop_vars(["XC", "YC", "Z"]))     # m^3/s
g_adv = g_adv.load()

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")

## calculation of U_prime_T_bar

In [None]:
start_time = time.time()

## residual velocity
# Sum of the *MASS velocities and the bolus (*STAR) velocities.
ures = volflux.UVELMASS + vstar.UVELSTAR
vres = volflux.VVELMASS + vstar.VVELSTAR
wres = volflux.WVELMASS + vstar.WVELSTAR

# Anomaly: residual velocity minus its climatological counterpart.
uprime = (ures - (clim['vol'].UVELMASS + clim['vstar'].UVELSTAR)).rename('UVELMASS_anom')
vprime = (vres - (clim['vol'].VVELMASS + clim['vstar'].VVELSTAR)).rename('VVELMASS_anom')
wprime = (wres - (clim['vol'].WVELMASS + clim['vstar'].WVELSTAR)).rename('WVELMASS_anom')

# prepare template
# The template describes the map_blocks result: two THETA-shaped variables.
chunk_dict = dict(zip(tsdaily.THETA.dims, [c[0] for c in tsdaily.THETA.chunks]))
template = xr.Dataset( { "G_Hadv": xr.zeros_like(tsdaily["THETA"]).chunk(chunk_dict),
                         "G_Vadv": xr.zeros_like(tsdaily["THETA"]).chunk(chunk_dict), })
# uprime Tbar
# Inputs of the helper: climatological THETA from the "tsdaily" and "tsdailyS"
# files plus the three velocity anomalies.
ds_in = xr.Dataset({   "THETA":  clim["tsdaily"].THETA,
                       "sTHETA": clim["tsdailyS"].THETA,
                       "u": uprime, "v": vprime, "w": wprime } )
G_upTb = xr.map_blocks( tab.cal_GMlike_prime_transport_ds, ds_in,
             kwargs={"face_connections": face_connections, "ECCOgrid": ECCOgrid},
             template=template.drop_vars(["XC", "YC", "Z"]) ) # oC * m^3/s
# Dataset.load() fills the dataset in place, so the result needs no rebinding.
G_upTb.load()

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")


## calculation of Diffusion

In [None]:
start_time = time.time()

# Anomaly of the climatology's variables (diffusive temperature fluxes).
vars_needed = list(clim['tadv'].data_vars)
tbudget_prime = tbudget[vars_needed] - clim['tadv']  # oC/s (or whatever unit the source data uses)
# NOTE: chunk_dict is not defined in this cell; it is reused from an earlier
# cell, so that cell has to be run first.
template = xr.Dataset({
    "dif_hConvH": xr.zeros_like(tsdaily["THETA"]).chunk(chunk_dict),
    "dif_vConvH": xr.zeros_like(tsdaily["THETA"]).chunk(chunk_dict), })
G_diff = xr.map_blocks( tab.cal_T_diffusion_ds, tbudget_prime,
            kwargs={"face_connections": face_connections},
            template=template.drop_vars(["XC", "YC", "Z"]) )
G_diff = G_diff.load()
#

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")


## calculation of heat forcing
#### Note: geothermal heating is set to 0 because it has no anomaly (GEOFLX=0.0 below)

In [None]:
start_time = time.time()

# 7.7 Surface heat-flux forcing (climatology removed first)
vars_needed = ['oceQsw','TFLUX']
hflux_prime = hflux[vars_needed] - clim['hflux'][vars_needed]     # W/m^2 (or whatever unit the source data uses)
G_forcing = tab.cal_T_forcing(hflux_prime, ECCOgrid, GEOFLX=0.0)  # oC/s
# oC/s times volume -> oC * m^3/s, the same unit as the other fluxes; the
# factor rho_pot that turns it into oC * kg/s is applied later.
G_heat = (G_forcing * vol).transpose('time', 'k', 'face', 'j', 'i')  # oC * m^3/s
#
G_heat = G_heat.load()

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")


## collecting all anomaly budgets and deal with units

In [None]:
start_time = time.time()

# 8) Collect the terms into one Dataset, each multiplied by rho_pot so that
#    they come out in kg/s or oC*kg/s
ds_budgets = xr.Dataset(
     data_vars=dict(
            g_fw      =freshwater * rho_pot,
            g_UpTb_h  =G_upTb.G_Hadv * rho_pot,
            g_UpTb_v  =G_upTb.G_Vadv * rho_pot,
            #g_UbTp_h  =G_Hadv_ubar_tprime * rho_pot,
            #g_UbTp_v  =G_Vadv_ubar_tprime * rho_pot,  
            #g_uptp_h  =G_Hadv_uprime_tprime * rho_pot,
            #g_uptp_v  =G_Vadv_uprime_tprime * rho_pot,
            # The eddy terms are read directly from the "eddyforcing" climatology file.
            g_UpTp_h  =clim['eddyforcing'].Hnabla_eddy * rho_pot,
            g_UpTp_v  =clim['eddyforcing'].Vnabla_eddy * rho_pot,
            g_mix_h   =G_diff.dif_hConvH * rho_pot,
            g_mix_v   =G_diff.dif_vConvH * rho_pot,
            g_heat    =G_heat  * rho_pot
        ),
        coords=dict(
            time=tbudget.time,
            k=ECCOgrid.k,
            face=ECCOgrid.face,
            j=ECCOgrid.j,
            i=ECCOgrid.i
        ),
        attrs=dict(
            note="All *g_* terms are fluxes on ECCO v4r4 llc90 native grid. "
                 "Units: fluxes [kg/s] or [°C·kg/s]."
        ))

# NOTE(review): g_adv is rescaled here but is not stored in ds_budgets within
# this cell -- confirm whether it should be part of the budget.
g_adv     =g_adv       * rho_pot  # kg/s

# Mass is stored separately, on its own time dimension
ds_budgets['Mass'] = xr.DataArray(
        Mass,
        dims=('time_mass','k','face','j','i'),
        coords=dict(
            time_mass=Mass.time.values,  # time axis of Mass (one day longer than the fluxes)
            k=ECCOgrid.k,
            face=ECCOgrid.face,
            j=ECCOgrid.j,
            i=ECCOgrid.i
        ),
        attrs=dict(units="kg", note="State variable with extended time axis for trend calculation.") )

# Everything collected above is computed and brought into memory here.
ds_budgets = ds_budgets.load()
#

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")


## calculation of Tprime snapshot and region mask

In [None]:
# NOTE: start_time is reset here but no elapsed time is printed in this cell.
start_time = time.time()

# wmb using transform
# Snapshot temperature anomaly: snapshot THETA minus the climatological THETA
# selected on the snapshot time axis.
Tsnap_prime = tscache['THETA'] - clim_tsdaily['THETA']
maskC = ECCOgrid.maskC.copy()
# Region mask restricted to the grid's maskC cells.
# NOTE(review): region_mask is not used by any later cell shown in this
# notebook -- confirm whether the budgets should be masked with it.
region_mask = mask & maskC


## make a temperature coordinate

In [None]:
# 1) Gather every non-NaN temperature-anomaly sample (this pulls the whole
#    array into memory as a flat NumPy array)
T_flat = Tsnap_prime.values.flatten()
T_flat = T_flat[~np.isnan(T_flat)]
# 2) Place the bin edges at evenly spaced quantiles (here 100 bins)
n_bins = 100
quantiles = np.linspace(0, 1, n_bins + 1)
t_edges = np.quantile(T_flat, quantiles)
# 3) Derive the bin centres from the edges
tcenters = 0.5 * (t_edges[:-1] + t_edges[1:])
dtcenters = np.diff(t_edges)  # temperature width of each bin (non-uniform)
# 4) Wrap the edges as a DataArray along the `tcenter` dimension; this is the
#    target passed to the transforms further down
tcen_outer = xr.DataArray(t_edges, dims=['tcenter'])
#dtcenters = xr.DataArray(dtcenters, dims=['itcenter'],coords={"itcenter":np.arange(len(dtcenters))})


## calculation of Tprime snapshot at the vertical cell interfaces (k+1 levels)

In [None]:
start_time = time.time()

# Interpolate the anomaly along the Z axis after forward-filling at most one
# level along k; the interpolated field is loaded into memory.
Tsnap_prime_region_outer = grid_snap.interp(Tsnap_prime.ffill(dim="k",limit=1), 'Z', boundary='extend').load()
# NOTE(review): judging by its name and arguments the helper appends a bottom
# level taken from the cell-centre values so the result spans k_p1 -- confirm
# against temp_anom_budgets.
Tsnap_prime_region_outer_extended = tab.add_bottom_layer_from_cell( Tsnap_prime_region_outer, Tsnap_prime, ECCOgrid.k_p1, ECCOgrid.Zp1)

# Report the wall-clock time spent in this cell.
end_time = time.time()
elapsed = end_time - start_time
print(f"计算耗时：{elapsed:.2f} 秒")


## layered budgets via xgcm transform, then cumulative sums over all layers with T > Tcenter

In [None]:
# === transform terms
# Placeholders, one per term.  "g_adv" is never filled in this cell and stays
# an empty list.
layered_results = {name: [] for name in [
        "g_mix_h","g_mix_v","g_UpTb_h","g_UpTb_v","g_UpTp_h","g_UpTp_v",
        "g_heat","g_fw","g_adv","Mass" ]}
# Conservatively remap each flux term from the Z axis onto the temperature
# bins.  The target data is the interface temperature anomaly without its last
# time step, because the flux terms have one time step fewer than the snapshots.
for varname in ["g_mix_h","g_mix_v","g_heat","g_UpTb_h","g_UpTb_v","g_UpTp_h","g_UpTp_v"]:
    layered_results[varname] = grid_daily.transform( ds_budgets[varname], 'Z', 
                            target=tcen_outer,
                            method='conservative',
                            target_data=Tsnap_prime_region_outer_extended.isel(time=slice(0,-1)))
# Mass uses the snapshot grid and the full snapshot time axis.
layered_results["Mass"] = grid_snap.transform( ds_budgets["Mass"].rename(time_mass="time"), 'Z', 
                            target=tcen_outer,
                            method='conservative',
                            target_data=Tsnap_prime_region_outer_extended )
#
# expand g_fw
g_fw_expanded = ds_budgets["g_fw"].expand_dims(k=ds_budgets["Mass"].k).transpose("time","k","face","j","i")
# Set every level except k == 0 to zero, keeping only the surface level
g_fw_expanded = g_fw_expanded.where(g_fw_expanded.k == 0, 0)

layered_results["g_fw"] = grid_daily.transform( g_fw_expanded,
                            'Z', 
                            target=tcen_outer,
                            method='conservative',
                            target_data=Tsnap_prime_region_outer_extended.isel(time=slice(0,-1)) )
##
# -----------------------------
# Placeholders for the cumulative budgets; "g_adv" stays an empty list here too.
layered_cum_budgets = {name: [] for name in [
        "g_mix_h","g_mix_v","g_UpTb_h","g_UpTb_v","g_UpTp_h","g_UpTp_v",
        "g_heat","g_tend","g_fw","g_adv" ]}

# NOTE(review): vertical_pairwise_avg is a project helper; presumably it
# combines neighbouring temperature bins using dtcenters and sums over
# face/j/i -- confirm against cal_wmb_llc.
for varname in ["g_mix_h","g_mix_v","g_heat","g_UpTb_h","g_UpTb_v","g_UpTp_h","g_UpTp_v"]:
    layered_cum_budgets[varname] = vertical_pairwise_avg(layered_results[varname], layered_results[varname].tcenter, 
                                                         dtcenters,  dim="tcenter", sum_dims=("face", "j", "i"))
#
# cumulative along ">" tcenter
# Reverse tcenter, cumulate, reverse back: each entry then holds the horizontal
# total summed from its own bin to the last bin.
dt_seconds = 86400.
Mass_cum = layered_results["Mass"].sum(['i','j','face']).isel(tcenter=slice(None, None, -1)).cumsum(dim='tcenter').isel(tcenter=slice(None, None, -1))
# Tendency: difference between consecutive snapshots divided by one day, then
# re-stamped with the time of the earlier snapshot of each pair.
layered_cum_budgets["g_tend"] = Mass_cum.diff('time') / dt_seconds
layered_cum_budgets["g_tend"] = layered_cum_budgets["g_tend"].assign_coords({'time': Mass_cum.time[0:-1]})
#
layered_cum_budgets["g_fw"] = layered_results["g_fw"].sum(['i','j','face']).isel(tcenter=slice(None, None, -1)).cumsum(dim='tcenter').isel(tcenter=slice(None, None, -1))


## visual check: tendency vs. sum of the right-hand-side terms

In [None]:
import matplotlib.pyplot as plt
# Terms summed to form the right-hand side that is compared with the tendency.
rhs_vars = [
    'g_mix_h', 'g_mix_v',
    'g_UpTb_h', 'g_UpTb_v',
    'g_UpTp_h', 'g_UpTp_v',
    'g_heat','g_fw']

# Temperature-anomaly level to inspect; the nearest available tcenter is used.
t_layer = 0.15  # oC
g_tend_layer = layered_cum_budgets['g_tend'].sel(tcenter=t_layer, method='nearest')
rhs_layer = sum(layered_cum_budgets[var].sel(tcenter=t_layer, method='nearest') for var in rhs_vars)
res = g_tend_layer - rhs_layer

# Plot tendency, RHS and their residual against time
plt.figure(figsize=(10,5))
plt.plot(g_tend_layer.time, g_tend_layer, label='g_tend')
plt.plot(rhs_layer.time, rhs_layer, label='rhs (sum of other terms)')
plt.plot(res.time, res, label='residual', linestyle='--')
plt.axhline(0, color='k', linestyle=':')
plt.xlabel('time')
# g_tend is Mass (kg) differenced in time and divided by seconds, i.e. kg/s;
# the label used to read oC/s.
plt.ylabel('value (kg/s)')
plt.title(f'Check g_tend vs RHS at tcenter ~ {t_layer}°C')
plt.legend()
plt.show()
