---
# Create an ESMF mesh file for running WW3 in CESM3 LGM simulations
- Author: Jiang Zhu (jiangzhu@ucar.edu)
- Tools used
  - [pop_tools](https://pop-tools.readthedocs.io/en/latest/) to get the gx3v7 grid
  - xesmf to regrid the MOM6 LGM topography onto gx3v7
  - The regridding step depends on the LGM topography file produced by `create_mom6_bathy.ipynb`
  - UXarray to visualize the mesh files
---

In [1]:
import datetime
import os
import warnings

# Install the warning filter BEFORE the third-party imports: some of them emit
# warnings while being imported (a pkg_resources-related one showed up here),
# and a filter installed after the imports cannot suppress those.
# NOTE(review): this silences *all* warnings for the whole session; narrow it with
# `category=` / `message=` if regridding or I/O warnings should stay visible.
warnings.filterwarnings('ignore')

import numpy as np
import xarray as xr
import uxarray as ux
import xesmf

import pop_tools

  from pkg_resources import DistributionNotFound, get_distribution


---
## Input files
- Preindustrial WW3 mesh file, whose `elementMask` will be updated
- LGM topography file for MOM6 to get an LGM land-sea mask
- Standard gx3v7 grid for regridding

In [2]:
# Preindustrial WW3 ESMF mesh on gx3v7; its elementMask is the starting point for the LGM mask.
ww3_mesh_pre = (
    '/glade/campaign/cesm/cesmdata/inputdata/share/meshes/'
    'wgx3v7_240327_ESMFmesh.nc'
)
ds_mesh_pre = xr.open_dataset(filename_or_obj=ww3_mesh_pre)
ds_mesh_pre

In [3]:
# LGM (21 ka) MOM6 topography produced by create_mom6_bathy.ipynb.
mom_topo_lgm = (
    '/glade/work/jiangzhu/data/inputdata/mom/tx2_3v2/'
    'ocean_topo_tx2_3v2_240501_21ka_260119.nc'
)
# Rename x/y to lon/lat, the coordinate names the xesmf regridding below works with.
ds_mom_lgm = xr.open_dataset(mom_topo_lgm).rename(x='lon', y='lat')
ds_mom_lgm

In [4]:
# Standard POP gx3v7 grid from pop_tools; expose its T-point coordinates
# (TLAT/TLONG) under the lat/lon names used for regridding.
ds_gx3 = pop_tools.get_grid('POP_gx3v7')
ds_gx3 = ds_gx3.assign(lat=ds_gx3['TLAT'], lon=ds_gx3['TLONG'])
ds_gx3

---
## Output mesh file
- We only need to update the land-sea mask, `elementMask`

In [5]:
# Date stamp (yymmdd) that tags the output file name.
today = datetime.date.today().strftime("%y%m%d")
print(today)

# Insert the LGM tag and date stamp in front of the file extension only;
# str.replace(".nc", ...) would also rewrite any ".nc" appearing earlier in the name.
filename = os.path.basename(ww3_mesh_pre)
stem, ext = os.path.splitext(filename)
new_filename = f"{stem}_21ka_{today}{ext}"

work_dir = '/glade/work/jiangzhu/data/inputdata/mom/tx2_3v2'
ww3_mesh_lgm = os.path.join(work_dir, new_filename)
print(ww3_mesh_lgm)

260119
/glade/work/jiangzhu/data/inputdata/mom/tx2_3v2/wgx3v7_240327_ESMFmesh_21ka_260119.nc


---
## Step 1: Derive an LGM land-sea mask from the LGM topography and apply it to the preindustrial mesh

### Regrid MOM6 `depth` to gx3v7

In [6]:
%%time

regridder = xesmf.Regridder(
    ds_mom_lgm,
    ds_gx3,
    method="bilinear",
    periodic=True)

depth_lgm = regridder(ds_mom_lgm.depth)
depth_lgm

CPU times: user 2.27 s, sys: 84.2 ms, total: 2.35 s
Wall time: 2.6 s


### Update `elementMask` in the preindustrial mesh file

In [7]:
mask_pre = ds_mesh_pre.elementMask

# LGM land-sea mask on the gx3v7 grid: 1 where the regridded LGM depth is positive
# (ocean), 0 otherwise (land). NaN depths, if any, also compare False -> land.
mask_lgm = np.where(depth_lgm > 0, 1, 0).flatten().astype(np.int32)

# The mesh stores its elements as a 1-D array, so the flattened gx3v7 field must
# provide exactly one value per element.
# NOTE(review): this also assumes the mesh elements are ordered like a C-order
# flatten of the (nlat, nlon) gx3v7 grid — confirm against the mesh file.
if mask_lgm.size != mask_pre.size:
    raise ValueError(
        f"LGM mask has {mask_lgm.size} points but the mesh has "
        f"{mask_pre.size} elements; check the regridding target grid"
    )

# Only ocean -> land changes are applied: elements that are land at the LGM are
# masked out, every other element keeps its preindustrial mask value.
ds_mesh_lgm = ds_mesh_pre.copy(deep=True)
ds_mesh_lgm.elementMask.data = np.where(mask_lgm == 0, 0, mask_pre.values)

ds_mesh_lgm.attrs['Title']  = 'ESMF mesh for running WW3 in gx3v7 for CESM3 LGM'
ds_mesh_lgm.attrs['Author'] = 'Jiang Zhu (jiangzhu@ucar.edu)'
ds_mesh_lgm.attrs["Source_topo_file"] = mom_topo_lgm
ds_mesh_lgm.attrs['Script'] = 'create_wave_mesh.ipynb'
ds_mesh_lgm.attrs['More_info'] = 'https://github.com/NCAR/paleowg-recipes/cesm3_lgm'
ds_mesh_lgm.attrs["Date_created"] = datetime.datetime.now().isoformat()

ds_mesh_lgm.to_netcdf(ww3_mesh_lgm, format="NETCDF3_64BIT")

---
## Step 2: Plot the masks with UXarray to check the results

In [8]:
# Reopen both meshes with UXarray (each mesh file serves as both grid and data file).
# Distinct names keep the xarray datasets from the cells above intact, so those
# cells still work if they are re-run after this one.
uxds_mesh_pre = ux.open_dataset(ww3_mesh_pre, ww3_mesh_pre)
uxds_mesh_lgm = ux.open_dataset(ww3_mesh_lgm, ww3_mesh_lgm)

plot_opts = dict(height=400, width=600)

p1 = uxds_mesh_pre.elementMask.plot().opts(title='mask in PI mesh', **plot_opts)
p2 = uxds_mesh_lgm.elementMask.plot().opts(title='mask in LGM mesh', **plot_opts)

# LGM minus PI: negative where preindustrial ocean became LGM land, 0 elsewhere
# (the LGM mask is only ever the PI mask with extra land).
p3 = (uxds_mesh_lgm.elementMask - uxds_mesh_pre.elementMask).plot(
    cmap='BuRd_r').opts(title='diff', **plot_opts)

p = p1 + p2 + p3
p.cols(2)