In [19]:
import numpy as np
import xarray as xr
from concurrent.futures import ProcessPoolExecutor, as_completed
# from dask import delayed, compute
# from dask.distributed import get_client, default_client, LocalCluster, Client
from pyTMD.io import ATLAS
from src.model_utils import get_current_model
from src.pytmd_utils import read_netcdf_grid
maxWorkers = 6

In [2]:
BATHY_gridfile = '/home/bioer/python/tide/data_src/TPXO9_atlas_v5/grid_tpxo9_atlas_30_v5.nc'
tpxo_model_directory = '/home/bioer/python/tide/data_src'
tpxo_model_format = 'netcdf'
tpxo_compressed = False
tpxo_model_name = 'TPXO9-atlas-v5'
tpxo_model = get_current_model(
    tpxo_model_name, tpxo_model_directory, tpxo_model_format, tpxo_compressed)


In [4]:
def recompute_na_points(coord, lonz, latz, bathy_mask, bathy_data):
    ilat_idx, ilon_idx = coord

    if bathy_mask[ilat_idx, ilon_idx] or bathy_data[ilat_idx, ilon_idx] <= 0.0:  # If it's a land point or depth <= 0
        return None

    ilon = lonz[ilon_idx]
    ilat = latz[ilat_idx]

    results = {}
    for var_type in ['u', 'v']:
        amp, ph, _, _ = ATLAS.extract_constants(
            np.atleast_1d(ilon), np.atleast_1d(ilat),
            tpxo_model.grid_file,
            tpxo_model.model_file[var_type], type=var_type, method='spline',
            scale=tpxo_model.scale, compressed=tpxo_model.compressed
        )
        results[f"{var_type}_amp"] = amp
        results[f"{var_type}_ph"] = ph
    
    return results

In [3]:
input_file = "tpxo9_fillna01.zarr"
lonz, latz, bathy_z = read_netcdf_grid(BATHY_gridfile, variable='z')


In [5]:
bathy_mask = bathy_z.mask
bathy_data = bathy_z.data
print(bathy_mask.shape)


(5401, 10800)


In [6]:
ds = xr.open_zarr(input_file)

In [7]:
if True:
    # Check for NaNs in the four variables
    coords_to_recompute = set()
    variables = ['u_amp', 'v_amp', 'u_ph', 'v_ph']

    # nan_loc1 = set(map(tuple, np.argwhere(np.isnan(ds['u_amp'].values).any(axis=-1))))
    ## it seems cause memory crash? #nan_loc1.update(map(tuple, np.argwhere(np.isnan(ds['u_ph'].values).any(axis=-1))))    
    # nan_loc2 = set(map(tuple, np.argwhere(np.isnan(ds['v_amp'].values).any(axis=-1))))
    #### nan_loc2.update(map(tuple, np.argwhere(np.isnan(ds['v_ph'].values).any(axis=-1))))
    #intersecting_nans = nan_loc1.intersection(nan_loc2)
    #for ilat_idx, ilon_idx in intersecting_nans:
    for var in ['u_amp', 'v_amp']:
        print("Now process var to find na: ", var)
        ## it's slow # nan_locs = np.argwhere(np.isnan(ds[var].values))
        nan_locs = np.argwhere(np.isnan(ds[var].values).any(axis=-1))
        #It will get 204969 points if scan u_amp, v_amp, u_ph, v_ph     
        for loc in nan_locs:
            ilat_idx, ilon_idx = loc
            if not bathy_mask[ilat_idx, ilon_idx] and bathy_data[ilat_idx, ilon_idx] > 0.0:
                coords_to_recompute.add((ilat_idx, ilon_idx))
                #print(coords_to_recompute)
            #print(len(coords_to_recompute))
    # Parallelize the re-computation using ProcessPoolExecutor


Now process var to find na:  u_amp
Now process var to find na:  v_amp


In [8]:
total_points = len(coords_to_recompute)
print(f"Total points to process: {total_points}")

Total points to process: 178275


In [9]:
print(list(coords_to_recompute)[0:10])

[(3363, 8087), (4419, 10577), (1846, 10083), (4477, 9801), (3142, 6398), (3350, 6347), (1180, 7778), (1708, 10149), (1749, 10150), (4440, 8421)]


In [21]:
#NA find {(140, 6102), (138, 6232), (139, 6260), (140, 6010), .....} is huge about 204969 points.
ilat_idx = 1846 #5400 #138                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          r33333333333333333333333333333333333333333333333333333333333333333333333333333333333ilat_idx = 5400 #138
ilon_idx = 10083 #8455 #6232
print(ds["u_amp"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds["v_ph"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds.coords['lon'][ilon_idx], ds.coords['lat'][ilat_idx])
print(lonz[ilon_idx]) #check equality
print(latz[ilat_idx])
print(bathy_mask[ilat_idx, ilon_idx])
print(bathy_data[ilat_idx, ilon_idx])

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
[113.33001626 132.03097175 160.45599243 162.98771133  67.40618069
  76.1953091  157.48301599  92.99924785  89.2864663  246.77373432
 348.21351821 195.8106125   37.38540242 258.25494763  63.8474404 ]
<xarray.DataArray 'lon' ()>
array(336.13332953)
Coordinates:
    lon      float64 336.1 <xarray.DataArray 'lat' ()>
array(-28.46666779)
Coordinates:
    lat      float64 -28.47
336.1333295343293
-28.466667792594464
False
4866.6201171875


In [22]:
ampu, phu, _, _ = ATLAS.extract_constants(
        np.atleast_1d(lonz[ilon_idx]), np.atleast_1d(latz[ilat_idx]),
        tpxo_model.grid_file,
        tpxo_model.model_file['u'], type='u', method='spline',
        scale=tpxo_model.scale, compressed=tpxo_model.compressed
    )
print(ampu, phu)

[[3.5488494946943145 15.736559980206401 3.299996247556154
  9.913170054534874 39.58627119252877 176.34961057860122
  1.0585631069753543 65.29188093555867 17.544279539809963
  1.455588020995228 0.7314513389301621 0.43110939497561923
  7.462967894130535 1.4564532904016338 0.6295665686549752]] [[121.39103966566903 142.7525928986673 182.62312914178708
  191.62518186408047 322.908326448071 325.82775078991494
  259.1186728646128 339.09686857352546 332.7412375319506
  203.10672985887192 291.36801891766413 164.02271721943578
  291.60835961417854 205.23641016283855 10.350843474868027]]


In [23]:
ampv, phv, _, _ = ATLAS.extract_constants(
        np.atleast_1d(lonz[ilon_idx]), np.atleast_1d(latz[ilat_idx]),
        tpxo_model.grid_file,
        tpxo_model.model_file['v'], type='v', method='spline',
        scale=tpxo_model.scale, compressed=tpxo_model.compressed
    )
print(ampv, phv)

[[8.668813118062117 35.76033562499962 8.683169526784722 23.45806202528826
  47.0764734010927 213.35622310706765 0.25969733329667843
  77.97409128150774 21.970233688959325 1.4524161022783755
  0.7655314554124386 0.5279226470822542 6.575219141717764
  2.862807843370184 0.5041899446945962]] [[113.33001625627412 132.03097175093149 160.45599243485512
  162.9877113337471 67.40618069237203 76.19530909747947 157.4830159948237
  92.9992478520575 89.28646630226528 246.77373431721838
  348.21351820971336 195.81061250443756 37.385402422090806
  258.2549476327405 63.84744039893113]]


In [24]:
filtered_coords = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx>=1 and ilon_idx>=1 and ilat_idx<=len(latz)-2 and ilon_idx<=len(lonz)-2:
        neighbors = [
            (ilat_idx+1, ilon_idx),
            (ilat_idx, ilon_idx+1),
            (ilat_idx+1, ilon_idx+1)
        ]
    if all(neighbor in coords_to_recompute for neighbor in neighbors):
        filtered_coords.add((ilat_idx, ilon_idx))
        for neighbor in neighbors:
            filtered_coords.add(neighbor)

#coords_to_recompute = filtered_coords
print('After removing: ', len(filtered_coords))

After removing:  145561


In [None]:
filtered_coord3 = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx <= len(latz) - 3 and ilon_idx <= len(lonz) - 3:
        neighbors = [
            (ilat_idx+1, ilon_idx), 
            (ilat_idx+2, ilon_idx),
            (ilat_idx, ilon_idx+1), 
            (ilat_idx, ilon_idx+2),
            (ilat_idx+1, ilon_idx+1),
            (ilat_idx+1, ilon_idx+2),
            (ilat_idx+2, ilon_idx+1),
            (ilat_idx+2, ilon_idx+2)
        ]
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            print(neighbors)
            filtered_coord3.add((ilat_idx, ilon_idx))
            for neighbor in neighbors:
                filtered_coord3.add(neighbor)

print('After removing: ', len(filtered_coord3))                

In [29]:
#try one neibors to extract constants: [(1757, 10166), (1758, 10166), (1756, 10167), (1756, 10168), (1757, 10167), (1757, 10168), (1758, 10167), (1758, 10168)]
lon_chunk = lonz[10166:10169]
lat_chunk = latz[1756:1759]
lon_grid, lat_grid = np.meshgrid(lon_chunk, lat_chunk)
mampu, mphu, mD, mc = ATLAS.extract_constants(
    lon_grid.ravel(), lat_grid.ravel(),
    tpxo_model.grid_file,
    tpxo_model.model_file['u'], type='u', method='spline',
    scale=tpxo_model.scale, compressed=tpxo_model.compressed)
print(mampu, mphu)

[[2.9675224745830766 12.431708000569952 3.0023071293390213
  8.662898754348872 48.42061517276824 219.1841754005411
  1.0730080996791878 83.86950278506212 22.879305544359003
  1.6097733584708138 0.7168854566849281 0.5090492219958825
  8.92096394063201 3.4582581658987905 1.9584240253976168]
 [2.9579358800937934 12.38904314369 2.9940437123012464 8.644111979565274
  48.5142578990052 219.67412312344175 1.0743284565403532 84.0675180084725
  22.937657823537002 1.6052284652899584 0.7154738792649525
  0.5071281581946608 8.938617793078416 3.447970487380112
  1.9736361073348159]
 [2.947794223297378 12.344665388729778 2.9851795650747603
  8.623317657073434 48.60567803063941 220.1546347792494 1.075655369924026
  84.2614119476745 22.994910718999964 1.6004448954366146
  0.7139589904604071 0.5052023780197574 8.955756457211674
  3.4357823124863045 1.988885197272386]
 [2.9685420578616895 12.446326690007595 2.99938513612234
  8.659415466701299 48.360135716488905 218.92991769346906
  1.072960464726159 83.

In [8]:
filtered_coord5 = set()

for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx <= len(latz) - 5 and ilon_idx <= len(lonz) - 5:
        is_square = True
        for i in range(5):
            for j in range(5):
                if (ilat_idx + i, ilon_idx + j) not in coords_to_recompute:
                    is_square = False
                    break
            if not is_square:
                break
        if is_square:
            for i in range(5):
                for j in range(5):
                    filtered_coord5.add((ilat_idx + i, ilon_idx + j))

#coords_to_recompute = filtered_coord5
print('After removing: ', len(filtered_coord5))   

After removing:  135860


In [11]:
def filter_and_form_cluster5(coords_to_recompute):
    filtered_coords = set()
    clusters = []

    for ilat_idx, ilon_idx in coords_to_recompute:
        # Creating a list for all 5x5 neighbors of the current point
        neighbors = [(ilat_idx + dlat, ilon_idx + dlon) for dlat in range(5) for dlon in range(5)]
        
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            clusters.append(neighbors)
            for neighbor in neighbors:
                filtered_coords.add(neighbor)

    return clusters, filtered_coords

In [13]:
f5_cluster, f5_coords = filter_and_form_cluster5(coords_to_recompute)


In [14]:
print(len(f5_coords))
print(f5_cluster[0:10])
print(list(f5_coords)[0:10])

135860
[[(3363, 8087), (3363, 8088), (3363, 8089), (3363, 8090), (3363, 8091), (3364, 8087), (3364, 8088), (3364, 8089), (3364, 8090), (3364, 8091), (3365, 8087), (3365, 8088), (3365, 8089), (3365, 8090), (3365, 8091), (3366, 8087), (3366, 8088), (3366, 8089), (3366, 8090), (3366, 8091), (3367, 8087), (3367, 8088), (3367, 8089), (3367, 8090), (3367, 8091)], [(1846, 10083), (1846, 10084), (1846, 10085), (1846, 10086), (1846, 10087), (1847, 10083), (1847, 10084), (1847, 10085), (1847, 10086), (1847, 10087), (1848, 10083), (1848, 10084), (1848, 10085), (1848, 10086), (1848, 10087), (1849, 10083), (1849, 10084), (1849, 10085), (1849, 10086), (1849, 10087), (1850, 10083), (1850, 10084), (1850, 10085), (1850, 10086), (1850, 10087)], [(3142, 6398), (3142, 6399), (3142, 6400), (3142, 6401), (3142, 6402), (3143, 6398), (3143, 6399), (3143, 6400), (3143, 6401), (3143, 6402), (3144, 6398), (3144, 6399), (3144, 6400), (3144, 6401), (3144, 6402), (3145, 6398), (3145, 6399), (3145, 6400), (3145, 640

In [17]:
chunk = form_chunk_from_cluster(f5_cluster[0])
print(chunk)

(3363, 3368, 8087, 8092)


In [18]:
tampu, tphu, tD, tc = process_chunk(chunk, lonz, latz, tpxo_model, 'u')
print(tampu)
print(tphu)
print(tD)
print(tc)

[[5.950931125216676 89.30162217881436 56.37364908853846
  151.88024088540803 19.19506429036349 38.92646484374779
  3.5667361789277487 65.68881835937127 19.09810926649197
  3.0912261962888867 0.9493988037108835 1.0778349982367008
  3.247981092664746 0.3008219401041496 0.06758624712625755]
 [2.9067296567171055 81.07855489979771 54.68179453974286
  131.57834791100788 21.312708315642382 44.81933593750083
  3.529279294221363 61.15592094089788 17.152855914572328
  3.4695713209069936 1.0481101326320676 1.1309049855108055
  2.902743795643736 0.12152543275252858 0.02173913043478301]
 [2.329096815321225 81.37521701389043 55.947791883681624
  125.87666015625238 23.636924913194893 52.79497612847323
  3.8653411865235108 60.69475368923726 16.632651095920455
  3.942252943250943 1.2178203158908651 1.2172626919216811
  2.837687344021321 0.024845200114780475 0.0]
 [2.5441933964075667 83.63931925912921 56.17409146769663
  127.20907215589887 24.643626009480336 58.02881408005618
  4.022409160485428 60.9861

In [21]:
remaining_points = coords_to_recompute - f5_coords
if remaining_points:
    print("Remaining points at final: ", len(remaining_points))

Remaining points at final:  42415


In [15]:
def process_chunk(cluster, lonz, latz, tpxo_model, var_type, cluster_idx, cluster_num):
    start_lat, end_lat, start_lon, end_lon = cluster
    lon_chunk = lonz[start_lon:end_lon]
    lat_chunk = latz[start_lat:end_lat]
    lon_grid, lat_grid = np.meshgrid(lon_chunk, lat_chunk)

    amp, ph, D, c = ATLAS.extract_constants(
        lon_grid.ravel(), lat_grid.ravel(),
        tpxo_model.grid_file,
        tpxo_model.model_file[var_type], type=var_type, method='spline',
        scale=tpxo_model.scale, compressed=tpxo_model.compressed)

    # reshape back amp and ph
    amp = amp.reshape((end_lat - start_lat, end_lon - start_lon, -1))
    ph = ph.reshape((end_lat - start_lat, end_lon - start_lon, -1))
    print(f"Cluster index: {cluster_idx}/{cluster_num} for variable: {var_type}")

    return (cluster_idx, cluster_num, start_lat, end_lat, start_lon, end_lon, var_type, amp, ph)


def filter_and_form_clusters(coords_to_recompute):
    coords_to_recompute = set(coords_to_recompute)  # Ensure it's a set for efficient removal
    clusters = []

    while coords_to_recompute:
        ilat_idx, ilon_idx = next(iter(coords_to_recompute))  # Take one coord from the set without removing it

        neighbors = [(ilat_idx + dlat, ilon_idx + dlon) for dlat in range(5) for dlon in range(5)]

        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            start_lat, start_lon = min(neighbors, key=lambda x: (x[0], x[1]))
            end_lat, end_lon = max(neighbors, key=lambda x: (x[0], x[1]))
            # +1 because we want to include the last point when slicing
            clusters.append((start_lat, end_lat + 1, start_lon, end_lon + 1))

            for neighbor in neighbors:
                coords_to_recompute.discard(neighbor)  # Remove these neighbors from further consideration
        else:
            coords_to_recompute.discard((ilat_idx, ilon_idx))  # Remove the current point if not all its neighbors are in the set

    return clusters


In [None]:
ReComputeCoords = False
input_file = "tpxo9_fillna01.zarr"

if ReComputeCoords:
    ds = xr.open_zarr(input_file)

    # Check for NaNs in the four variables
    coords_to_recompute = set()
    for var in ['u_amp', 'v_amp']:
        print("Now process var to find na: ", var)
        nan_locs = np.argwhere(np.isnan(ds[var].values).any(axis=-1))
        for loc in nan_locs:
            ilat_idx, ilon_idx = loc
            if not bathy_mask[ilat_idx, ilon_idx] and bathy_data[ilat_idx, ilon_idx] > 0.0:
                coords_to_recompute.add((ilat_idx, ilon_idx))



In [16]:
total_points = len(coords_to_recompute)
print(f"Total points to process: {total_points}")

Total points to process: 178275


In [17]:

# Filter coordinates and form 5x5 clusters
clusters = filter_and_form_clusters(coords_to_recompute)
print(clusters)


[(3363, 3368, 8087, 8092), (1846, 1851, 10083, 10088), (3142, 3147, 6398, 6403), (3350, 3355, 6347, 6352), (1180, 1185, 7778, 7783), (1708, 1713, 10149, 10154), (1749, 1754, 10150, 10155), (4440, 4445, 8421, 8426), (5379, 5384, 5408, 5413), (3292, 3297, 6342, 6347), (1858, 1863, 10093, 10098), (3174, 3179, 7784, 7789), (1761, 1766, 10160, 10165), (329, 334, 9076, 9081), (1134, 1139, 7783, 7788), (1703, 1708, 10155, 10160), (5333, 5338, 5413, 5418), (5374, 5379, 5414, 5419), (1812, 1817, 10098, 10103), (1853, 1858, 10099, 10104), (3128, 3133, 7789, 7794), (1715, 1720, 10165, 10170), (324, 329, 9082, 9087), (3120, 3125, 6393, 6398), (1129, 1134, 7789, 7794), (4606, 4611, 8467, 8472), (1865, 1870, 10109, 10114), (4418, 4423, 8416, 8421), (5381, 5386, 8463, 8468), (3336, 3341, 8088, 8093), (3074, 3079, 6398, 6403), (3115, 3120, 6399, 6404), (1771, 1776, 7414, 7419), (4668, 4673, 9695, 9700), (1860, 1865, 10115, 10120), (4413, 4418, 8422, 8427), (3348, 3353, 8098, 8103), (5376, 5381, 8469, 

In [20]:
print(maxWorkers)

6


In [None]:
ReTest = True
if ReTest:
    with ProcessPoolExecutor(max_workers=maxWorkers) as executor:
        futures_list = []
        total_clusters = len(clusters)
        for idx, chunk in enumerate(clusters):
            for var_type in ['u', 'v']:
                future = executor.submit(process_chunk, chunk, lonz, latz, tpxo_model, var_type, idx, total_clusters)
                futures_list.append(future)

        for future in as_completed(futures_list):
            idx_processed, cluster_num, start_lat, end_lat, start_lon, end_lon, var_type_processed, amp, ph = future.result()
            print(f"Processed cluster index: {idx_processed}/{cluster_num} for variable: {var_type_processed}")

            # Refill zarr dataset based on the variable type
            if var_type_processed == 'u':
                ds['u_amp'][start_lat:end_lat, start_lon:end_lon, :] = amp
                ds['u_ph'][start_lat:end_lat, start_lon:end_lon, :] = ph
            else:
                ds['v_amp'][start_lat:end_lat, start_lon:end_lon, :] = amp
                ds['v_ph'][start_lat:end_lat, start_lon:end_lon, :] = ph    

In [None]:
ReSave = True
if ReSave:
    outfile = 'tpxo9.zarr'
    print("Start to re-write zarr dataset...")
    ds.to_zarr(outfile, mode='w')
#ds.close()
