In [1]:
import numpy as np
import xarray as xr
from concurrent.futures import ProcessPoolExecutor, as_completed
# from dask import delayed, compute
# from dask.distributed import get_client, default_client, LocalCluster, Client
from pyTMD.io import ATLAS
from src.model_utils import get_current_model
from src.pytmd_utils import read_netcdf_grid
maxWorkers = 6

In [2]:
BATHY_gridfile = '/home/bioer/python/tide/data_src/TPXO9_atlas_v5/grid_tpxo9_atlas_30_v5.nc'
tpxo_model_directory = '/home/bioer/python/tide/data_src'
tpxo_model_format = 'netcdf'
tpxo_compressed = False
tpxo_model_name = 'TPXO9-atlas-v5'
tpxo_model = get_current_model(
    tpxo_model_name, tpxo_model_directory, tpxo_model_format, tpxo_compressed)


In [3]:
def recompute_na_points(coord, lonz, latz, bathy_mask, bathy_data):
    ilat_idx, ilon_idx = coord

    if bathy_mask[ilat_idx, ilon_idx] or bathy_data[ilat_idx, ilon_idx] <= 0.0:  # If it's a land point or depth <= 0
        return None

    ilon = lonz[ilon_idx]
    ilat = latz[ilat_idx]

    results = {}
    for var_type in ['u', 'v']:
        amp, ph, _, _ = ATLAS.extract_constants(
            np.atleast_1d(ilon), np.atleast_1d(ilat),
            tpxo_model.grid_file,
            tpxo_model.model_file[var_type], type=var_type, method='spline',
            scale=tpxo_model.scale, compressed=tpxo_model.compressed
        )
        results[f"{var_type}_amp"] = amp
        results[f"{var_type}_ph"] = ph
    
    return results

In [4]:
input_file = "../data/tpxo9.zarr"


In [5]:
lonz, latz, bathy_z = read_netcdf_grid(BATHY_gridfile, variable='z')


In [6]:
bathy_mask = bathy_z.mask
bathy_data = bathy_z.data
print(bathy_mask.shape)


(5401, 10800)


In [7]:
ds = xr.open_zarr(input_file, chunks='auto', decode_times=False, consolidated=True) 
print(ds['lat'].values[2700])
#ds = ds.chunk({'lat': 113, 'lon': 113, 'constituents': 8}) 
print(ds.info)

0.0
<bound method Dataset.info of <xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) <U3 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>

In [30]:
ds['lat'].values[2700] = 0
ds = ds.sortby('lat')
print(ds)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) object 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>


In [8]:
print(lonz)
print(latz)
print(ds['lon'].values)
print(ds['lat'].values)
print(len(lonz))
print(len(latz))
print(len(ds['lon'].values))
print(len(ds['lat'].values))

[3.33333340e-02 6.66666670e-02 9.99999999e-02 ... 3.59933329e+02
 3.59966663e+02 3.59999996e+02]
[-90.00000356 -89.96667023 -89.93333689 ...  89.93333689  89.96667023
  90.00000356]
[3.33333340e-02 6.66666670e-02 9.99999999e-02 ... 3.59933329e+02
 3.59966663e+02 3.59999996e+02]
[-90.00000356 -89.96667023 -89.93333689 ...  89.93333689  89.96667023
  90.00000356]
10800
5401
10800
5401


In [41]:
#ds['lon'].values = lonz
#ds['lat'].values = latz


In [None]:
#ds_new = ds.sel(lat=latz, lon=lonz, method='nearest')
#print(ds_new)

In [7]:
#del ds
#ds = ds_new
print(ds)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) object 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>


In [10]:
print(ds.coords["constituents"].values)

['q1' 'o1' 'p1' 'k1' 'n2' 'm2' 's1' 's2' 'k2' 'm4' 'ms4' 'mn4' '2n2' 'mf'
 'mm']


In [9]:
import zarr

store = zarr.open(input_file, mode='r')  # Open in read mode

# List variables with object dtype
for var_name in store.array_keys():
    if store[var_name].dtype == object:
        print(f"{var_name} has dtype=object")

constituents has dtype=object


In [20]:
#Rescale amp by 0.01 https://github.com/tsutterley/pyTMD/discussions/241
Rescale = False
if Rescale:
    ds = xr.open_zarr(input_file, chunks={'lat': 113, 'lon': 113, 'constituents': 8})
    # Scale the needed variables
    variables_to_scale = ['u_amp', 'v_amp']
    ds_scaled = ds.assign({var: ds[var] * 0.01 for var in variables_to_scale})
    ds_scaled['constituents'] = np.array(['q1', 'o1', 'p1', 'k1', 'n2', 'm2', 's1', 's2', 'k2', 'm4', 'ms4', 'mn4', '2n2', 'mf', 'mm'], dtype=str)

In [21]:
# Save the corrected dataset
if Rescale:
    corrected_file_path = '../data/tpxo9_fillna05.zarr'  # New file path to save the corrected dataset
    ds_scaled.to_zarr(corrected_file_path, mode='w')  # Use mode='w' to overwrite if using the same file path
    print(f"Dataset scaled and saved to {corrected_file_path}")    


Dataset scaled and saved to ../data/tpxo9_fillna05.zarr


In [35]:
ds = xr.open_zarr(input_file, chunks='auto', decode_times=False, consolidated=True) 

In [8]:
All_NA_CONDITION = False
if True:
    # Check for NaNs in the four variables
    coords_to_recompute = set()
    variables = ['u_amp', 'v_amp', 'u_ph', 'v_ph']

    # nan_loc1 = set(map(tuple, np.argwhere(np.isnan(ds['u_amp'].values).any(axis=-1))))
    ## it seems cause memory crash? #nan_loc1.update(map(tuple, np.argwhere(np.isnan(ds['u_ph'].values).any(axis=-1))))    
    # nan_loc2 = set(map(tuple, np.argwhere(np.isnan(ds['v_amp'].values).any(axis=-1))))
    #### nan_loc2.update(map(tuple, np.argwhere(np.isnan(ds['v_ph'].values).any(axis=-1))))
    #intersecting_nans = nan_loc1.intersection(nan_loc2)
    #for ilat_idx, ilon_idx in intersecting_nans:
    for var in ['u_amp', 'v_ph']:
        print("Now process var to find na: ", var)
        ## it's slow # nan_locs = np.argwhere(np.isnan(ds[var].values))
        if All_NA_CONDITION:
            nan_locs = np.argwhere(np.isnan(ds[var].values).all(axis=-1))
        else:    
            nan_locs = np.argwhere(np.isnan(ds[var].values).any(axis=-1))
        #It will get 204969 points if scan u_amp, v_amp, u_ph, v_ph     
        for loc in nan_locs:
            ilat_idx, ilon_idx = loc
            if not bathy_mask[ilat_idx, ilon_idx] and bathy_data[ilat_idx, ilon_idx] > 0.0:
                coords_to_recompute.add((ilat_idx, ilon_idx))
                #print(coords_to_recompute)
            #print(len(coords_to_recompute))
    # Parallelize the re-computation using ProcessPoolExecutor


Now process var to find na:  u_amp
Now process var to find na:  v_ph


In [32]:
#all_na_list_to_recompute = np.copy(list(coords_to_recompute))
#print(len(all_na_list_to_recompute))

In [9]:
#Total points to process: 178275 before first-time 5x5 neighbors with NA recomputation
#Total points to process: 109452 before first-time 4x4 neighbors NA recomputation and if All_NA_CONDITION = True -> 100388(fillna03) -> 89334(fillna04) -> 84282(fillna05) -> 82502(06, with extrapolate) -> 79654(07, 2*3 neighbors)
#79654 -> 73114 (08, 1*3 neighbors)
#All_na_condition set False 508981 pts -> 428631 pts(09, 5x5) -> 380901(10, 2x5) -> 340137(11, 2x3) -> 308379(12, 3x2) -> 285428 -> 273804(13, 2*4)
# -> 231480 (14, 2x2)
total_points = len(coords_to_recompute)
print(f"Total points to process: {total_points}")

Total points to process: 231480


In [25]:
print(list(coords_to_recompute)[0:10])

[(2988, 8964), (3650, 7363), (4046, 947), (4419, 10577), (5071, 8288), (4517, 532), (5009, 7410), (3574, 8307), (3789, 3765), (4040, 1132)]


In [50]:
#NA find {(140, 6102), (138, 6232), (139, 6260), (140, 6010), .....} is huge about 204969 points.
ilat_idx = 4601 #1846 #5400 #138                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          r33333333333333333333333333333333333333333333333333333333333333333333333333333333333ilat_idx = 5400 #138
ilon_idx = 9278 #10083 #8455 #6232
print(ds["u_amp"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds["v_ph"].isel(lat=ilat_idx, lon=ilon_idx).values)
print(ds.coords['lon'][ilon_idx], ds.coords['lat'][ilat_idx])
print(lonz[ilon_idx]) #check equality
print(latz[ilat_idx])
print(bathy_mask[ilat_idx, ilon_idx])
print(bathy_data[ilat_idx, ilon_idx])

[2.37540944e+01 2.37350708e+02 1.59934387e+02 5.17129517e+02
 1.40964583e+03 1.43437513e+04 2.22089208e+01 2.10834782e+03
 7.46692790e+02 5.22872009e+01 1.24571482e+01 1.24571482e+01
 1.21741130e+02 3.90156651e+00 1.13345591e+00]
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
<xarray.DataArray 'lon' ()>
array(309.2999965)
Coordinates:
    lon      float64 309.3 <xarray.DataArray 'lat' ()>
array(63.36666917)
Coordinates:
    lat      float64 63.37
309.2999965043527
63.36666917297668
False
6.0


In [46]:
ampu, phu, _, _ = ATLAS.extract_constants(
        np.atleast_1d(lonz[ilon_idx]), np.atleast_1d(latz[ilat_idx]),
        tpxo_model.grid_file,
        tpxo_model.model_file['u'], type='u', method='spline',
        scale=tpxo_model.scale, compressed=tpxo_model.compressed
    )
print(ampu, phu)

[[6.471731252495115 35.33086601374958 16.516280589743747
  41.467513024211506 142.83056067766648 664.3835523804238
  2.444686523899555 234.189589513611 67.98351421495443 87.0246193313029
  63.81208862198095 35.57978240963184 14.91632397993586
  0.5217650702235797 0.1304412675558949]] [[299.92393051356146 311.93218171600927 331.7145695996912
  312.8010739859887 300.80447167117035 325.3005762992059
  260.7889751355719 352.1406908598158 348.51758798627856
  357.0142950755634 55.725218494572914 289.09922132068266
  285.8593620690221 90.00000250447816 90.00000250447816]]


In [23]:
ampv, phv, _, _ = ATLAS.extract_constants(
        np.atleast_1d(lonz[ilon_idx]), np.atleast_1d(latz[ilat_idx]),
        tpxo_model.grid_file,
        tpxo_model.model_file['v'], type='v', method='spline',
        scale=tpxo_model.scale, compressed=tpxo_model.compressed
    )
print(ampv, phv)

[[8.668813118062117 35.76033562499962 8.683169526784722 23.45806202528826
  47.0764734010927 213.35622310706765 0.25969733329667843
  77.97409128150774 21.970233688959325 1.4524161022783755
  0.7655314554124386 0.5279226470822542 6.575219141717764
  2.862807843370184 0.5041899446945962]] [[113.33001625627412 132.03097175093149 160.45599243485512
  162.9877113337471 67.40618069237203 76.19530909747947 157.4830159948237
  92.9992478520575 89.28646630226528 246.77373431721838
  348.21351820971336 195.81061250443756 37.385402422090806
  258.2549476327405 63.84744039893113]]


In [11]:
filtered_coords = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx>=1 and ilon_idx>=1 and ilat_idx<=len(latz)-2 and ilon_idx<=len(lonz)-2:
        neighbors = [
            (ilat_idx+1, ilon_idx),
            (ilat_idx, ilon_idx+1),
            (ilat_idx+1, ilon_idx+1)
        ]
    if all(neighbor in coords_to_recompute for neighbor in neighbors):
        filtered_coords.add((ilat_idx, ilon_idx))
        #print(neighbors)
        for neighbor in neighbors:
            filtered_coords.add(neighbor)

#coords_to_recompute = filtered_coords
print('After removing: ', len(filtered_coords))

After removing:  260768


In [12]:
filtered_coord3 = set()
for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx <= len(latz) - 3 and ilon_idx <= len(lonz) - 3:
        neighbors = [
            (ilat_idx+1, ilon_idx), 
            (ilat_idx+2, ilon_idx),
            (ilat_idx, ilon_idx+1), 
            (ilat_idx, ilon_idx+2),
            (ilat_idx+1, ilon_idx+1),
            (ilat_idx+1, ilon_idx+2),
            (ilat_idx+2, ilon_idx+1),
            (ilat_idx+2, ilon_idx+2)
        ]
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            #print(neighbors)
            filtered_coord3.add((ilat_idx, ilon_idx))
            for neighbor in neighbors:
                filtered_coord3.add(neighbor)
#### Note this mehtod re-evaluate the points which may already had been included in previous neighborhoods 
#### so that they may be repeated in clusters and may to be recompute by extract_constant 
#### because the member in each cluster may be duplicated
print('After removing: ', len(filtered_coord3))                

After removing:  174938


In [None]:
#try one neibors to extract constants: [(1757, 10166), (1758, 10166), (1756, 10167), (1756, 10168), (1757, 10167), (1757, 10168), (1758, 10167), (1758, 10168)]
lon_chunk = lonz[10166:10169]
lat_chunk = latz[1756:1759]
lon_grid, lat_grid = np.meshgrid(lon_chunk, lat_chunk)
mampu, mphu, mD, mc = ATLAS.extract_constants(
    lon_grid.ravel(), lat_grid.ravel(),
    tpxo_model.grid_file,
    tpxo_model.model_file['u'], type='u', method='spline',
    scale=tpxo_model.scale, compressed=tpxo_model.compressed)
print(mampu, mphu)

In [16]:
filtered_coord5 = set()

for ilat_idx, ilon_idx in coords_to_recompute:
    if ilat_idx <= len(latz) - 5 and ilon_idx <= len(lonz) - 5:
        is_square = True
        for i in range(5):
            for j in range(5):
                if (ilat_idx + i, ilon_idx + j) not in coords_to_recompute:
                    is_square = False
                    break
            if not is_square:
                break
        if is_square:
            for i in range(5):
                for j in range(5):
                    filtered_coord5.add((ilat_idx + i, ilon_idx + j))

#coords_to_recompute = filtered_coord5
print('After removing: ', len(filtered_coord5))   

After removing:  0


In [None]:
#filtered_coords = set()
#for ilat_idx, ilon_idx in coords_to_recompute:
#    if ilat_idx>=1 and ilon_idx>=1 and ilat_idx<=len(latz)-2 and ilon_idx<=len(lonz)-2:
#        neighbors = [
#            (ilat_idx+1, ilon_idx),
#            (ilat_idx, ilon_idx+1),
#            (ilat_idx+1, ilon_idx+1)
#        ]
#    if all(neighbor in coords_to_recompute for neighbor in neighbors):
#        filtered_coords.add((ilat_idx, ilon_idx))
#        for neighbor in neighbors:
#            filtered_coords.add(neighbor)

In [27]:
def filter_and_form_cluster2(coords_to_recompute):
    filtered_coords = set()
    clusters = []

    for ilat_idx, ilon_idx in coords_to_recompute:
        # Creating a list for all 5x5 neighbors of the current point
        neighbors = [(ilat_idx + dlat, ilon_idx + dlon) for dlat in range(2) for dlon in range(2)]
        
        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            clusters.append(neighbors)
            for neighbor in neighbors:
                filtered_coords.add(neighbor)

    return clusters, filtered_coords

In [28]:
f2_cluster, f2_coords = filter_and_form_cluster2(coords_to_recompute)

In [29]:
print(len(f2_coords))
print(f2_cluster[0:10])
print(list(f2_coords)[0:10])

194604
[[(3650, 7363), (3650, 7364), (3651, 7363), (3651, 7364)], [(3789, 3765), (3789, 3766), (3790, 3765), (3790, 3766)], [(4327, 408), (4327, 409), (4328, 408), (4328, 409)], [(4908, 3571), (4908, 3572), (4909, 3571), (4909, 3572)], [(1852, 479), (1852, 480), (1853, 479), (1853, 480)], [(308, 4816), (308, 4817), (309, 4816), (309, 4817)], [(3565, 7440), (3565, 7441), (3566, 7440), (3566, 7441)], [(4790, 5013), (4790, 5014), (4791, 5013), (4791, 5014)], [(5036, 8452), (5036, 8453), (5037, 8452), (5037, 8453)], [(4448, 711), (4448, 712), (4449, 711), (4449, 712)]]
[(3650, 7363), (1532, 8984), (4419, 10577), (5071, 8288), (3574, 8307), (4517, 532), (3789, 3765), (4143, 4263), (4327, 408), (5101, 10382)]


In [None]:
remaining_points = coords_to_recompute - f2_coords
if remaining_points:
    print("Remaining points at final: ", len(remaining_points))

In [10]:
def process_chunk(cluster, lonz, latz, tpxo_model, var_type, cluster_idx, cluster_num):
    start_lat, end_lat, start_lon, end_lon = cluster
    lon_chunk = lonz[start_lon:end_lon]
    lat_chunk = latz[start_lat:end_lat]
    lon_grid, lat_grid = np.meshgrid(lon_chunk, lat_chunk)
    scale = 1e-4 ## replace tpxo_model.scale before pyTMD new release

    amp, ph, D, c = ATLAS.extract_constants(
        lon_grid.ravel(), lat_grid.ravel(),
        tpxo_model.grid_file,
        tpxo_model.model_file[var_type], type=var_type, method='spline',
        scale=scale, compressed=tpxo_model.compressed, extrapolate=True)

    # reshape back amp and ph
    amp = amp.reshape((end_lat - start_lat, end_lon - start_lon, -1))
    ph = ph.reshape((end_lat - start_lat, end_lon - start_lon, -1))
    print(f"Cluster index: {cluster_idx}/{cluster_num} for variable: {var_type} wtih scale: {scale}")

    return (cluster_idx, cluster_num, start_lat, end_lat, start_lon, end_lon, var_type, amp, ph)


def filter_and_form_clusters(coords_to_recompute, neighborx=1, neighbory=4):
    coords_to_recompute = set(coords_to_recompute)  # Ensure it's a set for efficient removal
    clusters = []

    while coords_to_recompute:
        ilat_idx, ilon_idx = next(iter(coords_to_recompute))  # Take one coord from the set without removing it

        neighbors = [(ilat_idx + dlat, ilon_idx + dlon) for dlat in range(neighbory) for dlon in range(neighborx)]

        if all(neighbor in coords_to_recompute for neighbor in neighbors):
            start_lat, start_lon = min(neighbors, key=lambda x: (x[0], x[1]))
            end_lat, end_lon = max(neighbors, key=lambda x: (x[0], x[1]))
            # +1 because we want to include the last point when slicing
            clusters.append((start_lat, end_lat + 1, start_lon, end_lon + 1))

            for neighbor in neighbors:
                coords_to_recompute.discard(neighbor)  # Remove these neighbors from further consideration
        else:
            coords_to_recompute.discard((ilat_idx, ilon_idx))  # Remove the current point if not all its neighbors are in the set

    return clusters


In [31]:
ReComputeCoords = False
# input_file = "tpxo9_fillna04.zarr"
# ds = xr.open_zarr(input_file)

if ReComputeCoords:
    # Check for NaNs in the four variables
    coords_to_recompute = set()
    for var in ['u_amp', 'v_amp']:
        print("Now process var to find na: ", var)
        nan_locs = np.argwhere(np.isnan(ds[var].values).any(axis=-1))
        for loc in nan_locs:
            ilat_idx, ilon_idx = loc
            if not bathy_mask[ilat_idx, ilon_idx] and bathy_data[ilat_idx, ilon_idx] > 0.0:
                coords_to_recompute.add((ilat_idx, ilon_idx))



In [32]:
total_points = len(coords_to_recompute)
print(f"Total points to process: {total_points}")

Total points to process: 380901


In [11]:
clusters_1x3 = filter_and_form_clusters(coords_to_recompute, neighborx=1, neighbory=3)
print(clusters_1x3)
print(len(clusters_1x3)) #44431(09) -> 29396(10)

[(4529, 4532, 542, 543), (4494, 4497, 706, 707), (4053, 4056, 963, 964), (4386, 4389, 614, 615), (4018, 4021, 1127, 1128), (3916, 3919, 850, 851), (4598, 4601, 655, 656), (538, 541, 8795, 8796), (343, 346, 8327, 8328), (4517, 4520, 8371, 8372), (2180, 2183, 4212, 4213), (1558, 1561, 8595, 8596), (2323, 2326, 1487, 1488), (4366, 4369, 631, 632), (4986, 4989, 2941, 2942), (3919, 3922, 292, 293), (1890, 1893, 450, 451), (4730, 4733, 2197, 2198), (3977, 3980, 7063, 7064), (2035, 2038, 3433, 3434), (3994, 3997, 1001, 1002), (3959, 3962, 1165, 1166), (2343, 2346, 3936, 3937), (3425, 3428, 1581, 1582), (3812, 3815, 7130, 7131), (2190, 2193, 3670, 3671), (4626, 4629, 638, 639), (4119, 4122, 8943, 8944), (4854, 4857, 1993, 1994), (3766, 3769, 719, 720), (4638, 4641, 648, 649), (4760, 4763, 2331, 2332), (3387, 3390, 1073, 1074), (3612, 3615, 3635, 3636), (4073, 4076, 8948, 8949), (1991, 1994, 9445, 9446), (4033, 4036, 8768, 8769), (2353, 2356, 412, 413), (4025, 4028, 956, 957), (3532, 3535, 1023

In [13]:
clusters_5x1 = filter_and_form_clusters(coords_to_recompute, neighborx=5, neighbory=1)
print(clusters_5x1)
print(len(clusters_5x1)) #15857(10)

[(5071, 5072, 8288, 8293), (5400, 5401, 9668, 9673), (5400, 5401, 3252, 3257), (5400, 5401, 6635, 6640), (238, 239, 5925, 5930), (5400, 5401, 219, 224), (5056, 5057, 10208, 10213), (5400, 5401, 10018, 10023), (5400, 5401, 3602, 3607), (4501, 4502, 722, 727), (5029, 5030, 7257, 7262), (3998, 3999, 101, 106), (4809, 4810, 8422, 8427), (3589, 3590, 8160, 8165), (5147, 5148, 8767, 8772), (4913, 4914, 2305, 2310), (5088, 5089, 616, 621), (5400, 5401, 8058, 8063), (5400, 5401, 1642, 1647), (5400, 5401, 5025, 5030), (384, 385, 6368, 6373), (4918, 4919, 4072, 4077), (5400, 5401, 9481, 9486), (5400, 5401, 3065, 3070), (4878, 4879, 3892, 3897), (5400, 5401, 6448, 6453), (5400, 5401, 32, 37), (5400, 5401, 9831, 9836), (4788, 4789, 592, 597), (4495, 4496, 720, 725), (4948, 4949, 7589, 7594), (5400, 5401, 4488, 4493), (5224, 5225, 6706, 6711), (5400, 5401, 7871, 7876), (5400, 5401, 1455, 1460), (5400, 5401, 4838, 4843), (4786, 4787, 8708, 8713), (5111, 5112, 547, 552), (3348, 3349, 3268, 3273), (49

In [12]:
clusters_3x1 = filter_and_form_clusters(coords_to_recompute, neighborx=3, neighbory=1)
print(clusters_3x1)
print(len(clusters_3x1))

[(5071, 5072, 8288, 8291), (3574, 3575, 8307, 8310), (3333, 3334, 2179, 2182), (4947, 4948, 3499, 3502), (4842, 4843, 5203, 5206), (2725, 2726, 3262, 3265), (5400, 5401, 9668, 9671), (5400, 5401, 3252, 3255), (5400, 5401, 6635, 6638), (238, 239, 5925, 5928), (5400, 5401, 219, 222), (1756, 1757, 10166, 10169), (5056, 5057, 10208, 10211), (5400, 5401, 10018, 10021), (5400, 5401, 3602, 3605), (3048, 3049, 8210, 8213), (4501, 4502, 722, 725), (5029, 5030, 7257, 7260), (2582, 2583, 3684, 3687), (3998, 3999, 101, 104), (4443, 4444, 717, 720), (4809, 4810, 8422, 8425), (3589, 3590, 8160, 8163), (5147, 5148, 8767, 8770), (4789, 4790, 6615, 6618), (4675, 4676, 9711, 9714), (4907, 4908, 7775, 7778), (5070, 5071, 3043, 3046), (1677, 1678, 558, 561), (4913, 4914, 2305, 2308), (4789, 4790, 549, 552), (2175, 2176, 4218, 4221), (5088, 5089, 616, 619), (4683, 4684, 9198, 9201), (5400, 5401, 8058, 8061), (5400, 5401, 1642, 1645), (4907, 4908, 1709, 1712), (4888, 4889, 3187, 3190), (5400, 5401, 5025, 50

In [37]:
print(ds.coords['lon'].values[5752], ds.coords['lat'].values[1114])

191.7666644997597 -52.86666875767544


In [43]:
def find_nearest_lonlat(dz, ilon, ilat):
    ilon = ilon + 360 if ilon < 0 else ilon
    abs_diff_lat = np.abs(dz['lat'].values - ilat)  # Calculate absolute difference from desired value
    nearest_lat_index = np.argmin(abs_diff_lat)  # Get index of smallest difference
    nearest_lat = dz['lat'].values[nearest_lat_index]
    abs_diff_lon = np.abs(dz['lon'].values - ilon)  
    nearest_lon_index = np.argmin(abs_diff_lon)  
    nearest_lon = dz['lon'].values[nearest_lon_index]
    return (nearest_lon, nearest_lat, nearest_lon_index, nearest_lat_index)

def find_cluster_index(clusters, lon_index, lat_index):
    for idx, (lat_start, lat_end, lon_start, lon_end) in enumerate(clusters):
        if lat_start <= lat_index < lat_end and lon_start <= lon_index < lon_end:
            return idx  # Return index of the found cluster
    return -1  # Return -1 if no cluster contains the provided indices

def find_recompoute_index(list_to_recompute, lon_index, lat_index):
    for idx, (lat, lon) in enumerate(list_to_recompute):
        if lat == lat_index and lon == lon_index:
            return idx  # Return index of the found coords
    return -1  # Return -1 if no coords contains the provided indices


In [61]:
clusters_chk = clusters_5x5
x1, y1, x1idx, y1idx = find_nearest_lonlat(ds, 122, -9) #existed example in 2x2: 191.76, -52.87
print(x1, y1, x1idx, y1idx)
chk_idx = find_cluster_index(clusters_chk, x1idx, y1idx)
print(chk_idx)
if chk_idx >= 0:
    print(clusters_chk[chk_idx])

chk_coords = find_recompoute_index(list(coords_to_recompute), x1idx, y1idx)
print(chk_coords)
if chk_coords >= 0:
    print(list(coords_to_recompute)[chk_coords])

chk_ds = ds.isel(lat=slice(y1idx, y1idx+1), lon=slice(x1idx, x1idx+1))
print(chk_ds['u_amp'].values, chk_ds['u_ph'].values, chk_ds['v_amp'].values, chk_ds['v_ph'].values)

121.99999862182065 -9.000000355972489 3659 2430
534
(2428, 2433, 3656, 3661)
475496
(2430, 3659)
[[[6.91453119e-02 1.65097759e-01 5.43460462e-02 2.41392962e-01
   5.68252491e-01 3.01851075e+00 8.67283057e-02 1.44436726e+00
   4.01194331e-01 7.50599604e-03 2.92606777e-03 2.61325997e-03
   9.71716825e-02 1.63083734e-01 6.47662973e-02]]] [[[ 66.29900703  68.61961161 102.20783183 104.95377279  69.63127191
    93.62177178 349.26245427 160.7580917  160.74522363 324.61925127
    84.36787633 205.05400313  17.86357601  51.81990074 356.71422352]]] [[[1.94456026e-02 4.49745854e-02 1.33725985e-02 6.27619769e-02
   1.64657124e-01 8.59102122e-01 2.52046901e-02 4.20720427e-01
   1.16438867e-01 2.24817211e-03 6.23693231e-04 7.39258797e-04
   2.92209951e-02 3.08506264e-02 8.58493185e-03]]] [[[         nan          nan          nan          nan          nan
            nan          nan          nan 158.61111736 323.33636416
    68.13856822 204.61323736  15.72410121  22.13877412 332.72600963]]]


In [18]:
# Filter coordinates and form 1x4 clusters (fillna05:445 -> 06, with extrapolate:0)
#clusters_1x4 = filter_and_form_clusters(coords_to_recompute, neighborx=1, neighbory=4)
#print(clusters_1x4)
#print(len(clusters_1x4))
# (filter06:475 -> 07, with extrapolate: 0)
#clusters_2x3 = filter_and_form_clusters(coords_to_recompute, neighborx=2, neighbory=3)
#print(clusters_2x3)
#print(len(clusters_2x3))

[]
0


In [10]:
clusters_test = filter_and_form_clusters(coords_to_recompute, neighborx=2, neighbory=2)
print(clusters_test)
print(len(clusters_test))

[(4327, 4329, 408, 410), (1852, 1854, 479, 481), (308, 310, 4816, 4818), (4673, 4675, 9689, 9691), (336, 338, 9092, 9094), (2268, 2270, 1804, 1806), (4460, 4462, 4454, 4456), (1656, 1658, 9130, 9132), (4625, 4627, 8843, 8845), (2989, 2991, 8248, 8250), (4665, 4667, 9665, 9667), (2872, 2874, 3475, 3477), (2025, 2027, 10153, 10155), (3866, 3868, 3528, 3530), (4788, 4790, 2015, 2017), (3459, 3461, 1810, 1812), (4632, 4634, 646, 648), (4437, 4439, 8418, 8420), (3964, 3966, 442, 444), (4729, 4731, 2352, 2354), (3347, 3349, 3260, 3262), (1982, 1984, 3398, 3400), (4470, 4472, 556, 558), (1475, 1477, 5287, 5289), (4213, 4215, 9009, 9011), (1770, 1772, 7433, 7435), (4468, 4470, 8373, 8375), (3926, 3928, 576, 578), (2869, 2871, 3472, 3474), (4378, 4380, 430, 432), (4786, 4788, 8334, 8336), (2271, 2273, 1809, 1811), (3868, 3870, 4249, 4251), (3149, 3151, 6359, 6361), (3144, 3146, 6365, 6367), (3842, 3844, 637, 639), (4909, 4911, 3666, 3668), (4838, 4840, 5304, 5306), (1994, 1996, 9450, 9452), (11

In [65]:
print(maxWorkers)

6


In [66]:
ReTest = True
if ReTest:
    with ProcessPoolExecutor(max_workers=maxWorkers) as executor:
        futures_list = []
        total_clusters = len(clusters)
        for idx, chunk in enumerate(clusters):
            for var_type in ['u', 'v']:
                future = executor.submit(process_chunk, chunk, lonz, latz, tpxo_model, var_type, idx, total_clusters)
                futures_list.append(future)

        for future in as_completed(futures_list):
            idx_processed, cluster_num, start_lat, end_lat, start_lon, end_lon, var_type_processed, amp, ph = future.result()
            print(f"Processed cluster index: {idx_processed}/{cluster_num} for variable: {var_type_processed}")

            # Refill zarr dataset based on the variable type
            if var_type_processed == 'u':
                ds['u_amp'][start_lat:end_lat, start_lon:end_lon, :] = amp
                ds['u_ph'][start_lat:end_lat, start_lon:end_lon, :] = ph
            else:
                ds['v_amp'][start_lat:end_lat, start_lon:end_lon, :] = amp
                ds['v_ph'][start_lat:end_lat, start_lon:end_lon, :] = ph    

Cluster index: 1/3214 for variable: u wtih scale: 0.0001
Cluster index: 2/3214 for variable: u wtih scale: 0.0001
Processed cluster index: 1/3214 for variable: u
Processed cluster index: 2/3214 for variable: u
Cluster index: 0/3214 for variable: u wtih scale: 0.0001
Processed cluster index: 0/3214 for variable: u
Cluster index: 2/3214 for variable: v wtih scale: 0.0001
Processed cluster index: 2/3214 for variable: v
Cluster index: 1/3214 for variable: v wtih scale: 0.0001
Cluster index: 0/3214 for variable: v wtih scale: 0.0001
Processed cluster index: 1/3214 for variable: v
Processed cluster index: 0/3214 for variable: v
Cluster index: 3/3214 for variable: u wtih scale: 0.0001
Processed cluster index: 3/3214 for variable: u
Cluster index: 5/3214 for variable: u wtih scale: 0.0001
Processed cluster index: 5/3214 for variable: u
Cluster index: 4/3214 for variable: u wtih scale: 0.0001
Processed cluster index: 4/3214 for variable: u
Cluster index: 4/3214 for variable: v wtih scale: 0.000

In [38]:
print(ds)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) <U3 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>


In [71]:
ReSave = True
#if ReSave:
#    outfile = 'tpxo9_tmp.zarr'
#    print("Start to re-write zarr dataset...")
#    ds.to_zarr(outfile, mode='w')
####ds.close()


In [75]:
import dask.array as da
from dask import delayed #, compute
from dask.diagnostics import ProgressBar

def save_dataset_dask(ds, store, chunk_size):
    # Convert xarray dataset to dask-backed arrays
    ds_dask = ds.chunk({'lat': chunk_size, 'lon': chunk_size})
    dask_arrays = {var: da.from_array(ds_dask[var].values, chunks=(chunk_size, chunk_size, -1)) for var in ds.data_vars}

    # Create placeholder arrays in the Zarr store
    for var_name, dask_arr in dask_arrays.items():
        zarr_arr = store.empty(var_name, shape=dask_arr.shape, dtype=dask_arr.dtype, chunks=dask_arr.chunksize)
        zarr_arr.attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']

    # Store dask-backed arrays to Zarr with parallel writes
    with ProgressBar():
        for var_name, dask_arr in dask_arrays.items():
            print("Current process: ", var_name, dask_arr)
            da.to_zarr(dask_arr, store[var_name], compute=True)


In [76]:
import xarray as xr
import zarr
from numcodecs import MsgPack
if ReSave:
   zarr_path = "../data/tmp_tpxo9.zarr"
   store = zarr.open(zarr_path, mode='w')

# Save dimension data
   for dim_name in ['lat', 'lon', 'constituents']:
        data = ds[dim_name].values
        chunks = ds[dim_name].encoding.get('chunks', ds[dim_name].shape)
        dtype = ds[dim_name].dtype
    
        # Check if dtype is object and handle accordingly
        if dtype == object:
            codec = MsgPack()
            arr = store.array(dim_name, data=data, chunks=chunks, dtype=dtype, object_codec=codec)
        else:
            arr = store.array(dim_name, data=data, chunks=chunks, dtype=dtype)
    
        arr.attrs['_ARRAY_DIMENSIONS'] = [dim_name]


In [77]:
# Assuming chunking over lat and lon as in your example
chunk_size = 338



In [78]:
save_dataset_dask(ds, store, chunk_size=chunk_size)


Current process:  u_amp dask.array<array, shape=(5401, 10800, 15), dtype=float64, chunksize=(338, 338, 15), chunktype=numpy.ndarray>
[########################################] | 100% Completed | 102.35 s
Current process:  u_ph dask.array<array, shape=(5401, 10800, 15), dtype=float64, chunksize=(338, 338, 15), chunktype=numpy.ndarray>
[########################################] | 100% Completed | 116.35 s
Current process:  v_amp dask.array<array, shape=(5401, 10800, 15), dtype=float64, chunksize=(338, 338, 15), chunktype=numpy.ndarray>
[########################################] | 100% Completed | 196.29 s
Current process:  v_ph dask.array<array, shape=(5401, 10800, 15), dtype=float64, chunksize=(338, 338, 15), chunktype=numpy.ndarray>
[########################################] | 100% Completed | 84.82 ss
Current process:  z_amp dask.array<array, shape=(5401, 10800, 15), dtype=float64, chunksize=(338, 338, 15), chunktype=numpy.ndarray>
[########################################] | 100% Com

In [None]:
# Create placeholder arrays with `_ARRAY_DIMENSIONS` attribute
for var_name in ds.data_vars:
    shape = ds[var_name].shape
    dtype = ds[var_name].dtype
    chunks = (chunk_size, chunk_size, ds['constituents'].shape[0])  # Assuming 3D data with constituents as the third dimension
    arr = store.empty(var_name, shape=shape, dtype=dtype, chunks=chunks)
    arr.attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']


In [30]:
# Write data in chunks
for i in range(0, len(ds['lat']), chunk_size):
    for j in range(0, len(ds['lon']), chunk_size):

        # Extract chunk from dataset
        ds_chunk = ds.isel(lat=slice(i, i+chunk_size), lon=slice(j, j+chunk_size))

        # Write chunk to appropriate location in Zarr store
        for var_name, variable in ds_chunk.data_vars.items():
            store[var_name][i:i+chunk_size, j:j+chunk_size, :] = variable.values


In [8]:
import zarr
store = zarr.open('../data/tpxo9.zarr', mode='a')
print(store)
print(store['u_amp'])
print(list(store['u_amp'].attrs))
print(store['u_amp'].attrs['_ARRAY_DIMENSIONS'] )

<zarr.hierarchy.Group '/'>
<zarr.core.Array '/u_amp' (5401, 10800, 15) float64>
['_ARRAY_DIMENSIONS']
['lat', 'lon', 'constituents']


In [9]:
#import zarr
ARRAY_DIMENSIONS_Err = False
if ARRAY_DIMENSIONS_Err:
    store = zarr.open('tpxo9.zarr', mode='a')
#### For example, for a data variable `u_amp` that has dimensions ('lat', 'lon', 'constituents'):
#### store['u_amp'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']
#### store['u_ph'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']
#### store['v_amp'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']
#### store['v_ph'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']
#### store['z_amp'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']
#### store['z_ph'].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']

    for var_name in store.array_keys():
        store[var_name].attrs['_ARRAY_DIMENSIONS'] = ['lat', 'lon', 'constituents']

In [61]:
# Correcting the attributes
if ARRAY_DIMENSIONS_Err:
    store['constituents'].attrs['_ARRAY_DIMENSIONS'] = ['constituents']
    store['lat'].attrs['_ARRAY_DIMENSIONS'] = ['lat']
    store['lon'].attrs['_ARRAY_DIMENSIONS'] = ['lon']

# Confirming the changes
for var_name in store.array_keys():
    print(var_name, store[var_name].attrs['_ARRAY_DIMENSIONS'])


constituents ['constituents']
lat ['lat']
lon ['lon']
u_amp ['lat', 'lon', 'constituents']
u_ph ['lat', 'lon', 'constituents']
v_amp ['lat', 'lon', 'constituents']
v_ph ['lat', 'lon', 'constituents']
z_amp ['lat', 'lon', 'constituents']
z_ph ['lat', 'lon', 'constituents']


In [32]:
zarr.convenience.consolidate_metadata('../data/tpxo9.zarr')

<zarr.hierarchy.Group '/'>

In [51]:
dz = xr.open_zarr('../data/tpxo9.zarr', chunks='auto', decode_times=False, consolidated=True) #, group="data_group") #, consolidated=False)
print(dz)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) object 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(338, 338, 15), meta=np.ndarray>


In [52]:
lat_values = dz['lat'].values
lon_values = dz['lon'].values

is_lat_monotonic = np.all(np.diff(lat_values) > 0)
is_lon_monotonic = np.all(np.diff(lon_values) > 0)

print("Is lat monotonic?", is_lat_monotonic)
print("Is lon monotonic?", is_lon_monotonic)


Is lat monotonic? False
Is lon monotonic? True


In [53]:
abs_diff_lat = np.abs(dz['lat'].values - 30)  # Calculate absolute difference from desired value
nearest_lat_index = np.argmin(abs_diff_lat)  # Get index of smallest difference
nearest_lat_value = dz['lat'].values[nearest_lat_index]
print(nearest_lat_index, nearest_lat_value)
print(dz['lat'].values[2698:2702])
print(latz[2698:2702])

2700 nan
[-0.06666667 -0.03333333         nan  0.03333333]
[-0.06666667 -0.03333333  0.          0.03333333]


In [54]:
dz['lat'].values[2700] = 0
dz = dz.sortby('lat')

In [31]:
ilon = 335 #122.26672
ilat = 30 #23.76175
#grid_sz = 1/30

dsub = dz.sel(lon=slice(ilon, ilon+5), lat=slice(ilat, ilat+5))
print(dsub)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 150, lon: 150)
Coordinates:
  * constituents  (constituents) object 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 30.0 30.03 30.07 30.1 ... 34.87 34.9 34.93 34.97
  * lon           (lon) float64 335.0 335.1 335.1 335.1 ... 339.9 340.0 340.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(118, 90, 15), meta=np.ndarray>


In [91]:
#dz.load()
#del dz

import gc
gc.collect()

0

In [68]:
#Three methods to re-save: compression, direct-overwrite, save by chunk of data
from numcodecs import Blosc

compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE)
encoding = {var: {'compressor': compressor} for var in dz.data_vars}

dz.to_zarr('tpxo9_compress.zarr', mode='w', encoding=encoding)

<xarray.backends.zarr.ZarrStore at 0x7fabe4a7d700>

In [71]:
#dz.to_zarr('tpxo9.zarr', mode='w')

<xarray.backends.zarr.ZarrStore at 0x7fabe4a7d620>

In [74]:
store = zarr.open('tpxo9.zarr', mode='a')
for var_name in dz.data_vars:
    store[var_name][:] = dz[var_name].values

zarr.convenience.consolidate_metadata('tpxo9.zarr')

<zarr.hierarchy.Group '/'>

In [None]:
#zarr_path = "tpxo9.zarr"
#store = zarr.open(zarr_path, mode='w')
#data_group = store.create_group("data_group")

# Save dimension data
#for dim_name in ['lat', 'lon', 'constituents']:
#    data_group.array(dim_name, data=ds[dim_name].values, chunks=ds[dim_name].encoding.get('chunks', ds[dim_name].shape), dtype=ds[dim_name].dtype)

#for var_name in ds.data_vars:
#    shape = ds[var_name].shape
#    dtype = ds[var_name].dtype
#    chunks = (chunk_size, shape[1])  # Assuming chunking only over lat for simplicity. Adjust if needed.
    
#    # This will create the Zarr array with the necessary metadata including `_ARRAY_DIMENSIONS`
#    data_group.zeros(name=var_name, shape=shape, dtype=dtype, chunks=chunks)

# Assuming chunking over lat as in your example
#for i in range(0, len(ds['lat']), chunk_size):
#    for j in range(0, len(ds['lon']), chunk_size):
        
#        # Extract chunk from dataset
#        ds_chunk = ds.isel(lat=slice(i, i+chunk_size), lon=slice(j, j+chunk_size))
        
#        # Write chunk to appropriate location in Zarr store
#        for var_name, variable in ds_chunk.data_vars.items():
#            data_group[var_name][i:i+chunk_size, j:j+chunk_size] = variable.values


In [40]:
print(ds.data_vars.items())

ItemsView(Data variables:
    u_amp    (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph     (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp    (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph     (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp    (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph     (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>)


In [49]:
# Assuming chunking over lat and lon as in your example
chunk_size = 338

#for i in range(0, len(ds['lat']), chunk_size):
#    for j in range(0, len(ds['lon']), chunk_size):
        
#        # Extract chunk from dataset
#        ds_chunk = ds.isel(lat=slice(i, i+chunk_size), lon=slice(j, j+chunk_size))
        
#        # Write chunk to appropriate location in Zarr store
#        for var_name, variable in ds_chunk.data_vars.items():
#            store['data_group'][var_name][i:i+chunk_size, j:j+chunk_size] = variable.values
for i in range(0, len(ds['lat']), chunk_size):
    for j in range(0, len(ds['lon']), chunk_size):
        
        ds_chunk = ds.isel(lat=slice(i, i+chunk_size), lon=slice(j, j+chunk_size))
        
        for var_name, variable in ds_chunk.data_vars.items():
            store[var_name][i:i+chunk_size, j:j+chunk_size] = variable.values


In [45]:
print(ds)

<xarray.Dataset>
Dimensions:       (constituents: 15, lat: 5401, lon: 10800)
Coordinates:
  * constituents  (constituents) <U3 'q1' 'o1' 'p1' 'k1' ... '2n2' 'mf' 'mm'
  * lat           (lat) float64 -90.0 -89.97 -89.93 -89.9 ... 89.93 89.97 90.0
  * lon           (lon) float64 0.03333 0.06667 0.1 0.1333 ... 359.9 360.0 360.0
Data variables:
    u_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    u_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    v_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_amp         (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>
    z_ph          (lat, lon, constituents) float64 dask.array<chunksize=(113, 113, 8), meta=np.ndarray>


In [31]:
from src.model_utils import *
from src.model_plot import *

In [41]:
x0, y0, x1, y1 = 118.0, 20.0, 129.75, 31.25 
#x0, y0, x1, y1 = 122.26672, 23.76175, 129.75, 31.25 #123.5, 28 #to test v is empty bug 
#x0, y0, x1, y1 = 123.75, 23.76175, 129.75, 31.25
grid_sz = 1/30
uvsub = ds.sel(lon=slice(x0-grid_sz, x1+grid_sz), lat=slice(y0-grid_sz, y1+grid_sz))
print(uvsub['u_amp'].values)
print(uvsub['u_ph'].values)


[[[3.40780429e+01 1.76900864e+02 7.33603518e+01 ... 7.70306980e+00
   2.52180309e+00 1.13589369e+00]
  [3.41874214e+01 1.77619202e+02 7.37721867e+01 ... 7.76010437e+00
   2.53577729e+00 1.12995988e+00]
  [3.43605686e+01 1.78656682e+02 7.43260726e+01 ... 7.83293066e+00
   2.55519998e+00 1.12495190e+00]
  ...
  [1.01463628e+01 5.59750145e+01 2.10913039e+01 ... 2.23903531e+00
   1.08214373e+00 5.79075154e-01]
  [1.01060248e+01 5.57398869e+01 2.09988956e+01 ... 2.24079671e+00
   1.08868636e+00 5.67125280e-01]
  [1.00693758e+01 5.55214496e+01 2.09130455e+01 ... 2.24051345e+00
   1.09571709e+00 5.57695126e-01]]

 [[3.44736787e+01 1.79628687e+02 7.43019971e+01 ... 7.68110040e+00
   2.56000119e+00 1.19766531e+00]
  [3.45859762e+01 1.80341152e+02 7.47122392e+01 ... 7.74710828e+00
   2.58265538e+00 1.19563928e+00]
  [3.47561387e+01 1.81338824e+02 7.52513701e+01 ... 7.82841818e+00
   2.61081451e+00 1.19304673e+00]
  ...
  [1.01541993e+01 5.60996469e+01 2.11035521e+01 ... 2.24029592e+00
   1.07003

In [28]:
from datetime import datetime, timedelta, timezone
start_date = datetime(2023, 7, 25)
end_date = datetime(2023, 7, 28)


tide_time, dtime = get_tide_time(start_date, end_date)
print(tide_time.shape)
print(tide_time)

(73,)
[11528.         11528.04166667 11528.08333333 11528.125
 11528.16666667 11528.20833333 11528.25       11528.29166667
 11528.33333333 11528.375      11528.41666667 11528.45833333
 11528.5        11528.54166667 11528.58333333 11528.625
 11528.66666667 11528.70833333 11528.75       11528.79166667
 11528.83333333 11528.875      11528.91666667 11528.95833333
 11529.         11529.04166667 11529.08333333 11529.125
 11529.16666667 11529.20833333 11529.25       11529.29166667
 11529.33333333 11529.375      11529.41666667 11529.45833333
 11529.5        11529.54166667 11529.58333333 11529.625
 11529.66666667 11529.70833333 11529.75       11529.79166667
 11529.83333333 11529.875      11529.91666667 11529.95833333
 11530.         11530.04166667 11530.08333333 11530.125
 11530.16666667 11530.20833333 11530.25       11530.29166667
 11530.33333333 11530.375      11530.41666667 11530.45833333
 11530.5        11530.54166667 11530.58333333 11530.625
 11530.66666667 11530.70833333 11530.75       11

In [29]:
tide_curr = get_tide_map(uvsub, tide_time[0:1])
print(tide_curr)

{'u': masked_array(
  data=[[[23.72593027829161],
         [25.157223797329543],
         [26.646060352055372],
         ...,
         [-41.00181919944255],
         [-41.126158320603615],
         [-41.260785965204605]],

        [[22.157362063894134],
         [23.676802492273527],
         [25.247048927128716],
         ...,
         [-40.88541130157549],
         [-41.019718328949295],
         [-41.15926368278692]],

        [[20.96925901833982],
         [22.475329958982613],
         [24.01451617892337],
         ...,
         [-40.74241811478311],
         [-40.90498343378043],
         [-41.07228971171635]],

        ...,

        [[--],
         [--],
         [--],
         ...,
         [-378.94717870579996],
         [-304.92537633941305],
         [-248.50507414305122]],

        [[--],
         [--],
         [--],
         ...,
         [-450.0165458781292],
         [-367.09003822535476],
         [-297.0294135774241]],

        [[--],
         [--],
         [--],
   

In [None]:
plot_current_map(x, y, u, v, mag, dtime[0])

In [None]:
#North Atlantic
x1, y1, u1, v1, mag1 = get_current_map(280, 0, 360, 60, ds, tide_time[0:1], mask_grid=5)
