In [None]:
import GMesh
import math
import netCDF4
import numpy as np
import matplotlib.pyplot as plt

In [None]:
%%time
# Read the source topography: lon/lat coordinates and a 2-D 'elevation' field.
# Masked (missing) values are replaced by 0.
# Example: cp -n /archive/gold/datasets/topography/GEBCO_2023/GEBCO_2023.nc source_topog.nc
# NOTE: despite its name, topo_depth holds the file's 'elevation' values (presumably positive up);
# main() negates the final mean when it writes the 'depth' variable.
with netCDF4.Dataset('source_topog.nc') as nc:
    topo_lon, topo_lat = ( nc[name][:].filled(0.) for name in ('lon', 'lat') )
    topo_depth = nc['elevation'][:, :].filled(0.)

In [None]:
# Wrap the source lon/lat/elevation arrays in GMesh's source-data object. main() reads the
# topography from src_topo_global (via refine_loop and project_source_data_onto_target_mesh).
# NOTE(review): UniformEDS presumably expects uniformly spaced lon/lat — confirm.
src_topo_global = GMesh.UniformEDS( topo_lon, topo_lat, topo_depth )
src_topo_global # display the object's repr as the cell output

In [None]:
# Read the target mesh from a supergrid file, keeping every second point in each direction
# (presumably the model-cell corners of an ocean_hgrid supergrid — confirm).
# Example: cp -n /archive/bgr/Datasets/OM5/topo_v4/OM5_025/ocean_hgrid.nc ocean_hgrid.nc
with netCDF4.Dataset('ocean_hgrid.nc') as nc:
    fullG = GMesh.GMesh(
        lon=nc['x'][::2, ::2],
        lat=nc['y'][::2, ::2],
    )
fullG

In [None]:
# Factorizations of the target grid size (presumably prime factors). main() asserts that NtileI
# divides ni and NtileJ divides nj, so these are useful for choosing the tile counts.
tuple( GMesh.pfactor( n ) for n in ( fullG.ni, fullG.nj ) )

In [None]:
def convol( levels, h, f, verbose=False ):
    """Average the product h*f from the finest mesh level down to the coarsest.

    Parameters
    ----------
    levels : list
        Mesh levels ordered coarsest first (``levels[0]``) to finest last (``levels[-1]``).
        Each level must provide ``nj``, ``ni``, a ``height`` attribute and a
        ``coarsenby2(coarser_level)`` method.
    h, f : array_like
        Factors whose (broadcast) product has ``levels[-1].nj * levels[-1].ni`` elements.
    verbose : bool
        If true, print each coarsening step.

    Returns
    -------
    The ``height`` attribute of ``levels[0]`` after coarsening (the stored object, not a copy).

    Side effects: assigns ``levels[-1].height`` and, presumably through ``coarsenby2``, the
    ``height`` of every coarser level.
    NOTE(review): if ``coarsenby2`` writes into an existing array in place, a later call would
    modify arrays returned by earlier calls — confirm it assigns a fresh array.
    """
    finest = levels[-1]
    finest.height = ( h * f ).reshape( finest.nj, finest.ni )
    # Walk from the finest level towards the coarsest, folding each level into its parent
    for k in reversed( range( 1, len( levels ) ) ):
        if verbose:
            print( 'Coarsening {} -> {}'.format( k, k - 1 ) )
        fine_level, coarse_level = levels[k], levels[k - 1]
        fine_level.coarsenby2( coarse_level )
    return levels[0].height

In [None]:
def rough( levels, h, h2min=1.e-7 ):
    """Calculate the mean of h and the variance of h about a best-fit plane, per coarse cell.

    Parameters
    ----------
    levels : list
        Mesh levels ordered coarsest first (``levels[0]``) to finest last (``levels[-1]``),
        as consumed by ``convol``. Each refinement is assumed to double nj and ni, so every
        coarse cell contains ``nx x nx`` fine cells with ``nx = 2**(len(levels)-1)``.
    h : array
        Values on the finest level; must be reshapeable to
        ``(levels[0].nj, nx, levels[0].ni, nx)``.
    h2min : float
        Small offset added to the returned variance (presumably to keep it away from zero
        and from round-off negatives).

    Returns
    -------
    H : mean of h over each coarse cell (as returned by ``convol``).
    H2 : variance of h relative to the least-squares plane in each coarse cell, plus h2min.

    With a single level (no refinement, nx == 1) no slope can be resolved. The variance is then
    ``<h^2> - <h>^2`` (= 0) plus h2min; previously this case raised ZeroDivisionError.
    """
    # Construct weights for moment calculations
    nx = 2**( len(levels) - 1 ) # fine cells per coarse cell in each direction
    if nx > 1:
        # Centred, unit-variance coordinates: satisfies <x>=0 and <x^2>=1 over the nx fine cells
        x = ( np.arange(nx) - ( nx - 1 ) / 2 ) * np.sqrt( 12 / ( nx**2 - 1 ) )
    else:
        # One fine cell per coarse cell: the unit-variance scaling is undefined (division by
        # zero), so use x=0, which makes <h*x> = <h*y> = 0 below.
        x = np.zeros(1)
    X, Y = np.meshgrid( x, x )
    # Shape (1,nx,1,nx) broadcasts against h reshaped to (nj,nx,ni,nx): axes 1 and 3 index the
    # fine cells inside each coarse cell
    X, Y = X.reshape(1,nx,1,nx), Y.reshape(1,nx,1,nx)
    h = h.reshape(levels[0].nj,nx,levels[0].ni,nx)
    # Now calculate moments
    H2 = convol( levels, h, h ) # mean of h^2
    HX = convol( levels, h, X ) # mean of h * x
    HY = convol( levels, h, Y ) # mean of h * y
    H = convol( levels, h, np.ones((1,nx,1,nx)) ) # mean of h = mean of h * 1
    # The variance of deviations from the plane = <h^2> - <h>^2 - <h*x>^2 - <h*y>^2,
    # given <x>=<y>=0, <x^2>=<y^2>=1 and <x*y>=0 (x and y are separable on the meshgrid)
    return H, H2 - H**2 - HX**2 - HY**2 + h2min

In [None]:
def _write_topog(outfile, Htarg, H2targ):
    """Write -Htarg as 'depth' and H2targ as 'h2' on ('ny','nx') to a new netCDF file.

    Any existing ``outfile`` is overwritten (mode 'w' with clobber=True). A length-1 'ntiles'
    dimension is also created, as in the original file layout; no variable here uses it.
    """
    with netCDF4.Dataset(outfile, 'w', clobber=True) as nc:
        nc.createDimension('nx', Htarg.shape[1])
        nc.createDimension('ny', Htarg.shape[0])
        nc.createDimension('ntiles', 1)
        depth = nc.createVariable('depth', float, ('ny','nx') )
        h2 = nc.createVariable('h2', float, ('ny','nx') )
        # Htarg is a mean of the source 'elevation' field; 'depth' is its negative
        depth[:,:] = -Htarg
        h2[:,:] = H2targ

def main(NtileI, NtileJ, max_refinement, write=True, plot=True, filestr="new_topo_OM5_grid",
         target_grid=None, source=None):
    """Compute mean source elevation and its variance about a plane on every target-grid cell.

    The target grid is processed in NtileJ x NtileI windows. Each window is refined with
    ``refine_loop(..., fixed_refine_level=max_refinement)`` and the finest level is filled by
    ``project_source_data_onto_target_mesh`` (nearest-neighbour, per the original notes).
    ``rough`` then coarsens the result back to the target cells.

    Parameters
    ----------
    NtileI, NtileJ : int
        Number of windows in i and j; must divide the grid's ni and nj exactly.
    max_refinement : int
        Fixed number of refinement levels passed to ``refine_loop``.
    write : bool
        If true, write ``{filestr}_r{max_refinement}_{NtileI}x{NtileJ}.nc``.
    plot : bool
        If true, pcolormesh the mean elevation on the target grid.
    filestr : str
        Prefix of the output file name.
    target_grid : GMesh, optional
        Target mesh; defaults to the notebook global ``fullG``.
    source : optional
        Source topography; defaults to the notebook global ``src_topo_global``.

    Raises
    ------
    AssertionError
        If NtileI does not divide ni or NtileJ does not divide nj.
    """
    if target_grid is None:
        target_grid = fullG
    if source is None:
        source = src_topo_global
    nj, ni = target_grid.nj, target_grid.ni

    di, dj = ni // NtileI, nj // NtileJ
    assert di*NtileI == ni, f"NtileI={NtileI} must divide ni={ni}"
    assert dj*NtileJ == nj, f"NtileJ={NtileJ} must divide nj={nj}"
    print('window size dj,di =',dj,di,'full model nj,ni=',nj, ni)
    Hcnt = np.zeros((nj, ni)) # Diagnostic: counting which cells we are working on
    Htarg, H2targ = np.zeros((nj, ni)), np.zeros((nj, ni))
    gtic = GMesh.GMesh._toc(None,"")
    for j in range( NtileJ ):
        # csj slices cells; sj slices lon/lat points and is one longer (points bound the cells)
        csj, sj = slice( j*dj, (j+1)*dj ), slice( j*dj, (j+1)*dj+1 )
        for i in range( NtileI ):
            csi, si = slice( i*di, (i+1)*di ), slice( i*di, (i+1)*di+1 ) # Slices of target grid
            Hcnt[csj,csi] += 1 # Diagnostic: counting which cells we are working on
            G = GMesh.GMesh( lon=target_grid.lon[sj,si], lat=target_grid.lat[sj,si] )
            print('J,I={},{} {:.1f}%, {}\n   window lon={}:{}, lat={}:{}\n   jslice={}, islice={}'.format(
                j, i, 100*(j*NtileI+i)/(NtileI*NtileJ), G, G.lon.min(), G.lon.max(), G.lat.min(), G.lat.max(), sj, si ))
            # Refine the window mesh a fixed number of times
            levels = G.refine_loop( source, resolution_limit=False, fixed_refine_level=max_refinement, timers=False )
            # Use nearest neighbor topography to populate the finest grid
            levels[-1].project_source_data_onto_target_mesh( source )
            # Now recursively coarsen, accumulating mean and sub-cell variance
            h, h2 = rough( levels, levels[-1].height )
            # Store window in final array
            Htarg[csj,csi] = h
            H2targ[csj,csi] = h2
    GMesh.GMesh._toc(gtic,"Whole workflow")
    print( Hcnt.min(), Hcnt.max(), '<-- should both be 1 for full model' )

    if write:
        outfile = f"{filestr}_r{max_refinement}_{NtileI}x{NtileJ}.nc"
        _write_topog(outfile, Htarg, H2targ)
        print(f"** wrote {outfile} **")

    if plot:
        fig, ax = plt.subplots()
        # lon/lat bound the cells (one more point than Htarg in each direction): flat shading
        mesh = ax.pcolormesh( target_grid.lon, target_grid.lat, Htarg )
        fig.colorbar( mesh, ax=ax, label='mean elevation' )
        ax.set( xlabel='longitude', ylabel='latitude',
                title=f'Mean elevation, r={max_refinement}, {NtileI}x{NtileJ} tiles' )

In [None]:
%%time
# Quick run with three refinement levels - useful for checking reproducibility
main(NtileI=1, NtileJ=9, max_refinement=3)

In [None]:
%%time
# Run with n=7 refinements (NtileI must divide ni and NtileJ must divide nj)
main(NtileI=3*2, NtileJ=43*3*3, max_refinement=7)

In [None]:
%%time
# Run with n=8 refinements (NtileI must divide ni and NtileJ must divide nj)
main(NtileI=3*2*2*2, NtileJ=43*3*3, max_refinement=8)

In [None]:
# Other (NtileI, NtileJ, max_refinement) combinations; run any of them as main(NtileI, NtileJ, max_refinement):
# NtileI, NtileJ, max_refinement = 1, 9, 3 
# NtileI, NtileJ, max_refinement = 3, 43*3*3, 0
# NtileI, NtileJ, max_refinement = 3, 43, 5
# NtileI, NtileJ, max_refinement = 3, 43*3, 6
# NtileI, NtileJ, max_refinement = 3*2, 43*3*3, 7
# NtileI, NtileJ, max_refinement = 3*2*2*2, 43*3*3, 8