In [1]:
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import islice

import numpy as np
import xarray as xr
# from dask import delayed, compute
# from dask.distributed import get_client, default_client, LocalCluster, Client
from pyTMD.io import ATLAS

from src.model_utils import get_current_model
from src.pytmd_utils import read_netcdf_grid
maxWorkers = 4  # worker count intended for ProcessPoolExecutor (imported above); not referenced by any cell shown here

In [2]:
# Root directory of the TPXO model files. Set the TPXO_DATA_DIR environment variable to
# use another location instead of editing this hard-coded default.
tpxo_model_directory = os.environ.get('TPXO_DATA_DIR', '/home/bioer/python/tide/data_src')
# Bathymetry grid (read below with read_netcdf_grid); derived from the model root so the
# two paths cannot drift apart.
BATHY_gridfile = os.path.join(
    tpxo_model_directory, 'TPXO9_atlas_v5', 'grid_tpxo9_atlas_30_v5.nc')
tpxo_model_format = 'netcdf'
tpxo_compressed = False
tpxo_model_name = 'TPXO9-atlas-v5'
tpxo_model = get_current_model(
    tpxo_model_name, tpxo_model_directory, tpxo_model_format, tpxo_compressed)


In [3]:
def recompute_na_points(coord, lonz, latz, bathy_mask, bathy_data):
    """Re-extract the TPXO u/v tidal constants at a single grid cell.

    Parameters
    ----------
    coord : tuple of int
        ``(ilat_idx, ilon_idx)`` index pair into ``latz`` / ``lonz`` and into the
        2-D bathymetry arrays (indexed ``[lat, lon]``).
    lonz, latz : array-like
        Longitude / latitude coordinate vectors of the bathymetry grid.
    bathy_mask, bathy_data : array-like
        Bathymetry mask and depth values, indexed ``[lat, lon]``.

    Returns
    -------
    dict or None
        ``None`` when the cell is masked or its depth is <= 0. Otherwise a dict with
        keys ``u_amp``, ``u_ph``, ``v_amp``, ``v_ph`` holding the amplitude and phase
        arrays returned by ``ATLAS.extract_constants`` for that one point.

    Notes
    -----
    Reads the module-level ``tpxo_model`` for grid/model file paths and settings.
    """
    lat_i, lon_i = coord

    # Masked cells and non-positive depths are skipped; `or` short-circuits so the
    # depth is only read for unmasked cells.
    is_masked = bathy_mask[lat_i, lon_i]
    if is_masked or bathy_data[lat_i, lon_i] <= 0.0:
        return None

    point_lon = lonz[lon_i]
    point_lat = latz[lat_i]

    extracted = {}
    for component in ('u', 'v'):
        # Fresh 1-element arrays per call, exactly as before.
        amplitude, phase, _, _ = ATLAS.extract_constants(
            np.atleast_1d(point_lon),
            np.atleast_1d(point_lat),
            tpxo_model.grid_file,
            tpxo_model.model_file[component],
            type=component,
            method='spline',
            scale=tpxo_model.scale,
            compressed=tpxo_model.compressed,
        )
        extracted[f"{component}_amp"] = amplitude
        extracted[f"{component}_ph"] = phase

    return extracted

In [4]:
# Bathymetry grid: coordinate vectors plus the 'z' field (masked/depth split in the next cell).
lonz, latz, bathy_z = read_netcdf_grid(BATHY_gridfile, variable='z')
# Zarr store with the previously extracted TPXO constants.
input_file = "tpxo9.zarr"


In [5]:
# Split the bathymetry into a boolean mask and raw depths.
# np.ma.getmaskarray always returns a full-shape boolean array, whereas `.mask` collapses
# to the scalar `np.ma.nomask` when no cell is masked; that scalar would break the
# [lat, lon] indexing used by later cells.
bathy_mask = np.ma.getmaskarray(bathy_z)
bathy_data = np.ma.getdata(bathy_z)
print(bathy_mask.shape)


(5401, 10800)


In [6]:
# Open the stored constants (lazily) from the Zarr store named in input_file.
ds = xr.open_zarr(store=input_file)

In [14]:
# Collect ocean cells whose stored u/v amplitudes are NaN for at least one entry along
# the last dimension; these are the cells to re-extract with recompute_na_points.
coords_to_recompute = set()
variables = ['u_amp', 'v_amp', 'u_ph', 'v_ph']  # NOTE: only the *_amp variables are scanned below

# Ocean ("wet") cells: unmasked and depth > 0 -- the same test recompute_na_points applies.
wet_cells = np.logical_not(bathy_mask) & (bathy_data > 0.0)

for var in ['u_amp', 'v_amp']:
    print("Now process var to find na: ", var)
    field = ds[var]
    # Reduce over the last dimension *before* pulling data into memory, so only a 2-D
    # boolean array is materialised instead of the full 3-D field via `.values`
    # (earlier attempts at loading the full fields reportedly ran out of memory).
    # When ds is dask-backed (open_zarr's default) this is computed chunk by chunk.
    has_nan = field.isnull().any(dim=field.dims[-1]).values
    # Vectorised replacement of the per-point Python loop over every NaN location.
    nan_wet_locs = np.argwhere(has_nan & wet_cells)
    coords_to_recompute.update(map(tuple, nan_wet_locs.tolist()))

# TODO: parallelise recompute_na_points over coords_to_recompute (e.g. ProcessPoolExecutor
# with maxWorkers workers); no cell shown here does this yet.


Now process var to find na:  u_amp
Now process var to find na:  v_amp


In [15]:
# Number of flagged ocean cells that still need their constants re-extracted.
total_points = len(coords_to_recompute)
print("Total points to process: {}".format(total_points))

Total points to process: 178275


In [20]:
# Peek at the first 10 flagged cells without copying the whole set into a list first.
print(list(islice(coords_to_recompute, 10)))

[(3363, 8087), (4419, 10577), (1846, 10083), (4477, 9801), (3142, 6398), (3350, 6347), (1180, 7778), (1708, 10149), (1749, 10150), (4440, 8421)]


In [21]:
# The NaN scan flags a large number of ocean cells (about 204,969 when all four
# variables are scanned), e.g. (140, 6102), (138, 6232), (139, 6260), (140, 6010).
# Spot-check one flagged cell against the stored dataset and the bathymetry grid.
ilat_idx = 1846  # alternative indices tried earlier: 5400, 138
ilon_idx = 10083  # alternative indices tried earlier: 8455, 6232
print(ds["u_amp"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds["v_ph"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds.coords['lon'][ilon_idx], ds.coords['lat'][ilat_idx])
# The dataset coordinates above should equal the bathymetry grid coordinates below.
print(lonz[ilon_idx])
print(latz[ilat_idx])
print(bathy_mask[ilat_idx, ilon_idx])
print(bathy_data[ilat_idx, ilon_idx])

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
[113.33001626 132.03097175 160.45599243 162.98771133  67.40618069
  76.1953091  157.48301599  92.99924785  89.2864663  246.77373432
 348.21351821 195.8106125   37.38540242 258.25494763  63.8474404 ]
<xarray.DataArray 'lon' ()>
array(336.13332953)
Coordinates:
    lon      float64 336.1 <xarray.DataArray 'lat' ()>
array(-28.46666779)
Coordinates:
    lat      float64 -28.47
336.1333295343293
-28.466667792594464
False
4866.6201171875


In [22]:
# Re-extract the u constants directly from the TPXO files at the spot-check cell,
# for comparison with the stored u_amp values printed in the previous cell.
point_lon = np.atleast_1d(lonz[ilon_idx])
point_lat = np.atleast_1d(latz[ilat_idx])
ampu, phu, _, _ = ATLAS.extract_constants(
    point_lon, point_lat, tpxo_model.grid_file, tpxo_model.model_file['u'],
    type='u', method='spline', scale=tpxo_model.scale, compressed=tpxo_model.compressed,
)
print(ampu, phu)

[[3.5488494946943145 15.736559980206401 3.299996247556154
  9.913170054534874 39.58627119252877 176.34961057860122
  1.0585631069753543 65.29188093555867 17.544279539809963
  1.455588020995228 0.7314513389301621 0.43110939497561923
  7.462967894130535 1.4564532904016338 0.6295665686549752]] [[121.39103966566903 142.7525928986673 182.62312914178708
  191.62518186408047 322.908326448071 325.82775078991494
  259.1186728646128 339.09686857352546 332.7412375319506
  203.10672985887192 291.36801891766413 164.02271721943578
  291.60835961417854 205.23641016283855 10.350843474868027]]


In [23]:
# Same re-extraction for the v component; in the recorded run the extracted v phases
# match the stored v_ph values printed earlier.
point_lon = np.atleast_1d(lonz[ilon_idx])
point_lat = np.atleast_1d(latz[ilat_idx])
ampv, phv, _, _ = ATLAS.extract_constants(
    point_lon, point_lat, tpxo_model.grid_file, tpxo_model.model_file['v'],
    type='v', method='spline', scale=tpxo_model.scale, compressed=tpxo_model.compressed,
)
print(ampv, phv)

[[8.668813118062117 35.76033562499962 8.683169526784722 23.45806202528826
  47.0764734010927 213.35622310706765 0.25969733329667843
  77.97409128150774 21.970233688959325 1.4524161022783755
  0.7655314554124386 0.5279226470822542 6.575219141717764
  2.862807843370184 0.5041899446945962]] [[113.33001625627412 132.03097175093149 160.45599243485512
  162.9877113337471 67.40618069237203 76.19530909747947 157.4830159948237
  92.9992478520575 89.28646630226528 246.77373431721838
  348.21351820971336 195.81061250443756 37.385402422090806
  258.2549476327405 63.84744039893113]]


In [24]:
# Keep flagged cells that form a fully flagged 2x2 block with their +1 lat, +1 lon and
# diagonal neighbours; every member of such a block is kept.
filtered_coords = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if 1 <= ilat_idx <= len(latz) - 2 and 1 <= ilon_idx <= len(lonz) - 2:
        neighbors = [
            (ilat_idx + 1, ilon_idx),
            (ilat_idx, ilon_idx + 1),
            (ilat_idx + 1, ilon_idx + 1),
        ]
        # The membership test must stay inside the bounds branch: outside it,
        # `neighbors` is either undefined or left over from a previous cell.
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            filtered_coords.add((ilat_idx, ilon_idx))
            filtered_coords.update(neighbors)

# Optional next step: coords_to_recompute = filtered_coords
print('After removing: ', len(filtered_coords))

After removing:  145561


In [None]:
# Same idea with 3x3 blocks: keep flagged cells whose whole 3x3 block (the cell plus its
# +1/+2 lat and lon neighbours) is flagged, together with every member of the block.
filtered_coord3 = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx <= len(latz) - 3 and ilon_idx <= len(lonz) - 3:
        # The 8 cells of the 3x3 block other than the anchor itself.
        neighbors = [
            (ilat_idx + di, ilon_idx + dj)
            for di in range(3)
            for dj in range(3)
            if (di, dj) != (0, 0)
        ]
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            # No per-match print: it fired once per matching block and flooded the output.
            filtered_coord3.add((ilat_idx, ilon_idx))
            filtered_coord3.update(neighbors)

print('After removing: ', len(filtered_coord3))

In [29]:
# Try extracting constants for one fully flagged 3x3 block in a single call.
# Block anchored at (1756, 10166); its neighbours are
# [(1757, 10166), (1758, 10166), (1756, 10167), (1756, 10168), (1757, 10167), (1757, 10168), (1758, 10167), (1758, 10168)]
lat_start, lon_start = 1756, 10166
lat_chunk = latz[lat_start:lat_start + 3]
lon_chunk = lonz[lon_start:lon_start + 3]
lon_grid, lat_grid = np.meshgrid(lon_chunk, lat_chunk)
mampu, mphu, mD, mc = ATLAS.extract_constants(
    lon_grid.ravel(),
    lat_grid.ravel(),
    tpxo_model.grid_file,
    tpxo_model.model_file['u'],
    type='u',
    method='spline',
    scale=tpxo_model.scale,
    compressed=tpxo_model.compressed,
)
print(mampu, mphu)

[[2.9675224745830766 12.431708000569952 3.0023071293390213
  8.662898754348872 48.42061517276824 219.1841754005411
  1.0730080996791878 83.86950278506212 22.879305544359003
  1.6097733584708138 0.7168854566849281 0.5090492219958825
  8.92096394063201 3.4582581658987905 1.9584240253976168]
 [2.9579358800937934 12.38904314369 2.9940437123012464 8.644111979565274
  48.5142578990052 219.67412312344175 1.0743284565403532 84.0675180084725
  22.937657823537002 1.6052284652899584 0.7154738792649525
  0.5071281581946608 8.938617793078416 3.447970487380112
  1.9736361073348159]
 [2.947794223297378 12.344665388729778 2.9851795650747603
  8.623317657073434 48.60567803063941 220.1546347792494 1.075655369924026
  84.2614119476745 22.994910718999964 1.6004448954366146
  0.7139589904604071 0.5052023780197574 8.955756457211674
  3.4357823124863045 1.988885197272386]
 [2.9685420578616895 12.446326690007595 2.99938513612234
  8.659415466701299 48.360135716488905 218.92991769346906
  1.072960464726159 83.

In [26]:
# Keep flagged cells lying in a fully flagged 5x5 block anchored at a flagged cell
# (offsets 0..4 in both lat and lon); every member of such a block is kept.
block_size = 5
filtered_coord5 = set()

for ilat_idx, ilon_idx in coords_to_recompute:
    fits_in_grid = (ilat_idx <= len(latz) - block_size
                    and ilon_idx <= len(lonz) - block_size)
    if not fits_in_grid:
        continue
    # all() stops at the first missing cell, like the original early breaks.
    block_complete = all(
        (ilat_idx + di, ilon_idx + dj) in coords_to_recompute
        for di in range(block_size)
        for dj in range(block_size)
    )
    if block_complete:
        filtered_coord5.update(
            (ilat_idx + di, ilon_idx + dj)
            for di in range(block_size)
            for dj in range(block_size)
        )

# Optional next step: coords_to_recompute = filtered_coord5
print('After removing: ', len(filtered_coord5))

After removing:  135860
