## installation

change to directory icon4py (directory root)
activate virtual env
```
pip install xugrid
pip install pymetis
```
change to `model/common/src/icon4py/model/common/decomposition/

In [1]:
import xugrid as xu
import functools as ft
from matplotlib import pyplot as plt
import numpy as np
import scipy as sp


In [2]:
#load the data set 

ds = xu.open_dataset('data/icon_grid_0013_R02B04_R_ugrid.nc')
grid = ds.ugrid.grid
print(f"loaded dataset {type(ds)}")
print(f"loaded dataset {type(ds.ugrid)}")

print("---------------------------")
print(f" ---- dimensions of entire dataset: {ds.dims}")
print("---------------------------")
print(" ---- grid property: ")
print(f" ---- type of accessor's grid: ugrid.grid {type(grid)}")
print(f" ---- grid: {grid}")



loaded dataset <class 'xugrid.core.wrap.UgridDataset'>
loaded dataset <class 'xugrid.core.dataset_accessor.UgridDatasetAccessor'>
---------------------------
---------------------------
 ---- grid property: 
 ---- type of accessor's grid: ugrid.grid <class 'xugrid.ugrid.ugrid2d.Ugrid2d'>
 ---- grid: <xarray.Dataset> Size: 3MB
Dimensions:                (cell: 20480, nv: 3, edge: 30720, nc: 2,
                            vertex: 10242)
Coordinates:
    vlon                   (vertex) float64 82kB ...
    vlat                   (vertex) float64 82kB ...
    clon                   (cell) float64 164kB ...
    clat                   (cell) float64 164kB ...
    elon                   (edge) float64 246kB ...
    elat                   (edge) float64 246kB ...
Dimensions without coordinates: cell, nv, edge, nc, vertex
Data variables:
    mesh                   int64 8B ...
    vertex_of_cell         (cell, nv) float64 492kB 0.0 1.0 ... 4.748e+03
    edge_vertices          (edge, nc) float64

Using vlon and vlat as projected x and y coordinates.
Using clon and clat as projected x and y coordinates.
Using elon and elat as projected x and y coordinates.


In [3]:
# decomposes the unstructured grid into 4 parts: 
# returns a list of n_part 
partitions = ds.ugrid.partition(n_part=4)
print(f"--- type(meshes)= {type(partitions)}")
part0 = partitions[0]
print(f"--- single partition: type of list entry:{type(part0)}")
print(f"--- single partition: type of list ugrid:{type(part0.ugrid)}")

print(f"--- single partition: grid {part0.grid}")

--- type(meshes)= <class 'list'>
--- single partition: type of list entry:<class 'xugrid.core.wrap.UgridDataset'>
--- single partition: type of list ugrid:<class 'xugrid.core.dataset_accessor.UgridDatasetAccessor'>
--- single partition: grid <xarray.Dataset> Size: 290kB
Dimensions:         (cell: 5120, mesh_nMax_face_nodes: 3, edge: 7770, two: 2,
                     vertex: 2651)
Coordinates:
    vlon            (vertex) float64 21kB 0.6283 1.257 1.885 ... 1.054 1.125
    vlat            (vertex) float64 21kB 1.54 1.517 1.54 ... -1.026 -1.013
Dimensions without coordinates: cell, mesh_nMax_face_nodes, edge, two, vertex
Data variables:
    mesh            int64 8B 0
    vertex_of_cell  (cell, mesh_nMax_face_nodes) int64 123kB 0 3 1 ... 2488 2650
    edge_vertices   (edge, two) int64 124kB 0 2 0 3 1 ... 2482 2650 2482 2488
Attributes:
    Conventions:  CF-1.9 UGRID-1.0


In [4]:

INVALID = -1 
horizontal_dims = ["cell", "edge", "vertex"]
part0.sizes["cell"]

    



    


5120

In [5]:

def inspect_sizes(ds, patches, dim_names):
    for d in dim_names:
        original = ds.sizes[d]
        _, summed = total_sizes(patches, d)
        sizes = list(map(lambda p: p.sizes[d], patches))
        if ds.sizes[d] == summed:
            print(f" matching: ds {d} size {original} = {summed} (partition sum) ({sizes})")

        else:
            
          print(f" NON MATCHING: {d}: ds {d} size {original} != {summed} (partition sum)({sizes}): difference = {summed - original} ")


            
def total_sizes(patches, dim_name):    
    size = ft.reduce(lambda a, b: a + b, map(lambda p: p.sizes[dim_name], patches))
    return dim_name, size

inspect_sizes(ds, partitions, horizontal_dims)

 matching: ds cell size 20480 = 20480 (partition sum) ([5120, 5120, 5120, 5120])
 NON MATCHING: edge: ds edge size 30720 != 31103 (partition sum)([7770, 7774, 7781, 7778]): difference = 383 
 NON MATCHING: vertex: ds vertex size 10242 != 10627 (partition sum)([2651, 2655, 2662, 2659]): difference = 385 


In [6]:
#inspect boundaries
grid0 = part0.ugrid.grid
grid0


<xarray.Dataset> Size: 290kB
Dimensions:         (cell: 5120, mesh_nMax_face_nodes: 3, edge: 7770, two: 2,
                     vertex: 2651)
Coordinates:
    vlon            (vertex) float64 21kB 0.6283 1.257 1.885 ... 1.054 1.125
    vlat            (vertex) float64 21kB 1.54 1.517 1.54 ... -1.026 -1.013
Dimensions without coordinates: cell, mesh_nMax_face_nodes, edge, two, vertex
Data variables:
    mesh            int64 8B 0
    vertex_of_cell  (cell, mesh_nMax_face_nodes) int64 123kB 0 3 1 ... 2488 2650
    edge_vertices   (edge, two) int64 124kB 0 2 0 3 1 ... 2482 2650 2482 2488
Attributes:
    Conventions:  CF-1.9 UGRID-1.0

In [7]:
grid0.edge_face_connectivity

array([[   1,   -1],
       [   0,   -1],
       [   0,   10],
       ...,
       [5117, 5119],
       [5118, 5119],
       [5119,   -1]])

In [8]:
grid0.edge_face_connectivity.shape

(7770, 2)

In [9]:
# find boundary edges (only 1 cell neighbor) - local indices on the part0
x, y, _ = sp.sparse.find(grid0.edge_face_connectivity == INVALID) 
x


array([   0,    1,    7,   10,   28, 1583, 1593, 1595, 1665, 1669, 1808,
       1862, 1868, 1872, 1874, 1875, 1882, 1890, 1892, 1895, 2314, 2315,
       2316, 2319, 2367, 2368, 2369, 2399, 2433, 2434, 2445, 2447, 2533,
       2538, 2546, 2548, 2552, 2555, 2557, 2558, 2559, 2562, 2564, 2568,
       2570, 2573, 2578, 2586, 2590, 2596, 2597, 2612, 2622, 2624, 2632,
       2646, 2649, 2669, 2672, 2697, 2701, 2730, 2734, 2736, 2738, 2768,
       2769, 4380, 4381, 4382, 4412, 4431, 4436, 4438, 4440, 4446, 4460,
       4462, 4463, 4464, 4467, 4468, 4472, 4475, 4476, 4483, 4841, 4926,
       4938, 4940, 4941, 4944, 4946, 4950, 4952, 5328, 5673, 5675, 5697,
       5700, 5702, 5703, 5704, 5714, 5720, 5723, 5727, 5728, 5735, 5787,
       5790, 5796, 5826, 5827, 6060, 6063, 6064, 6066, 6159, 6160, 6163,
       6168, 6178, 6186, 6188, 6199, 6399, 6400, 6419, 6423, 6464, 6476,
       6481, 6482, 6609, 6610, 6611, 6614, 6658, 6663, 6675, 6677, 7054,
       7055, 7058, 7062, 7069, 7077, 7078, 7079, 70

In [10]:
# neighbor cells of boundary edges in global grid: global cell index that should be in the halo of part0
halo_cells0 = grid.edge_face_connectivity[x, y]
halo_cells0.shape



(180,)

In [11]:
part0.cell_index.data

array([    2,     3,     4, ..., 20221, 20222, 20223], dtype=int32)

In [12]:
# ?? 
set(part0.cell_index.data).intersection(set(halo_cells0))


{2,
 6,
 9,
 37,
 1031,
 1071,
 1078,
 1079,
 1167,
 1209,
 1211,
 1213,
 1214,
 1215,
 1228,
 1230,
 1365,
 1502,
 1504,
 1507,
 1536,
 1538,
 1579,
 1583,
 1584,
 1595,
 1661,
 1667,
 1676,
 1692,
 1698,
 1711,
 1718,
 1732,
 1735,
 1741,
 1744,
 1748,
 1749,
 1754,
 1758,
 1783,
 1785,
 1788,
 1789,
 1877,
 4346,
 4347,
 4350,
 4351,
 4696,
 4700,
 5034,
 5040,
 5059,
 5064,
 5067,
 5081,
 5082,
 5088,
 5096,
 5098,
 5101,
 5114,
 7031,
 9557,
 10577}

In [13]:
# MERGING partitions again
import numpy as np
merged = xu.merge_partitions(partitions)
assert  np.allclose(merged["clat"], ds["clat"])





AssertionError: 

In [None]:
reordered = merged.ugrid.reindex_like(ds)
assert np.allclose(reordered["clat"], ds["clat"])


In [None]:
#partioning the grid only

grid = ds.ugrid.grid
grid_parts = grid.partition(n_part=4)

fig, axes = plt.subplots(2, 2, figsize=(12.6, 10))
for part, ax in zip(grid_parts, axes.ravel()):
    part.plot(ax=ax)

In [None]:
# partitioning labels
# returns an array which maps each cell to a partion number from [0, n_part-1]

labels = ds.ugrid.grid.label_partitions(n_part=4)
labels.ugrid.plot()
labels

In [None]:
print(f"{type(labels)}: {labels.sizes}")

In [None]:
#

In [None]:
#separate connected cmponents of a grid.
import xugrid
import xarray as xr

grid = ds.ugrid.grid

uda = xugrid.UgridDataArray(
    xr.DataArray(np.ones(grid.node_face_connectivity.shape[0]), dims=["face"]), grid
)
connected = uda.ugrid.connected_components()
connected.ugrid.plot(cmap="ocean")


In [None]:
# ds.grids returns a list with one element, which is unpacked below
uda0 = xugrid.UgridDataArray(
    xr.full_like(part0.obj["cell_area"], True, dtype=bool),
    part0.grid,
)

uda3 = xugrid.UgridDataArray(
    xr.full_like(partitions[3].obj["cell_area"], True, dtype=bool),
    partitions[3].grid,
)


In [None]:
halo0 = uda0.ugrid.binary_erosion(2)
halo3 = uda3.ugrid.binary_erosion(2)

In [None]:
np.sum(~halo0)



In [None]:
part0.obj["cell_index"][~halo0]

In [None]:
fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(20, 5))
halo0.ugrid.plot(ax=ax0)
halo3.ugrid.plot(ax=ax1)