In [None]:
import numpy as np
import xarray as xr

In [None]:
# Run switches for the optional processing steps in this notebook.
# True -> the border points of mask_rho reported by find_matches are set to 0.
fix_boundaries = False
# True -> the hard-coded rectangle of mask_rho in the "Mask out" section is set to 0.
mask_bad_area = True

In [None]:
# Input grid. Later cells edit its mask_* variables and write the result to a separate file.
grid_ds = xr.open_dataset('fram_data/norfjords_160m_grid_version2+++.nc')

## Fix mask_rho variable - boundaries check

In [None]:
def find_matches(border_position: str):
    """
    Locate border points of ``mask_rho`` that are 1 while their inward neighbour is 0.

    The two outermost rows/columns on the requested side are inspected. A
    position matches when the outermost value equals 1 and the value directly
    inside it equals 0. The sides map to the grid as follows: west is
    ``xi_rho == 0``, east is the last ``xi_rho``, south is ``eta_rho == 0`` and
    north is the last ``eta_rho``.

    Parameters
    ----------
    border_position : str
        One of ``'west'``, ``'north'``, ``'east'`` or ``'south'``.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_matches, 1)`` with the positions along the
        border: ``eta_rho`` indices for west/east, ``xi_rho`` indices for
        north/south. The matched indices are also printed.

    Raises
    ------
    ValueError
        If ``border_position`` is not one of the four supported sides.
    """
    # ``sizes`` is the name -> length mapping; using ``Dataset.dims`` as a
    # mapping is deprecated in xarray.
    n_eta = grid_ds.sizes['eta_rho']
    n_xi = grid_ds.sizes['xi_rho']

    # Per side: the dimension crossed when stepping inward, the slice holding
    # the two outermost rows/columns, and the wanted values in slice order.
    border_specs = {
        'west': ('xi_rho', slice(0, 2), (1., 0.)),
        'north': ('eta_rho', slice(n_eta - 2, n_eta), (0., 1.)),
        'east': ('xi_rho', slice(n_xi - 2, n_xi), (0., 1.)),
        'south': ('eta_rho', slice(0, 2), (1., 0.)),
    }
    if border_position not in tuple(border_specs):
        raise ValueError(
            f"Unknown border_position {border_position!r}; "
            f"expected one of {list(border_specs)}"
        )

    dim, strip_slice, wanted = border_specs[border_position]
    strip = grid_ds.mask_rho.isel(**{dim: strip_slice}).values
    if dim == 'eta_rho':
        # Assumes mask_rho is ordered (eta_rho, xi_rho): transpose so that every
        # row of ``strip`` is one position along the border, as for west/east.
        strip = strip.transpose()

    # A position matches only when both of its two values equal the pattern.
    is_match = np.all(strip == np.asarray(wanted), axis=1)
    idx = np.argwhere(is_match)  # shape (n_matches, 1)

    print(f"{border_position} indices: {idx.transpose()[0]}")
    return idx


In [None]:
def get_values(eta_rho, xi_rho):
    """
    Read ``mask_rho`` at the given positional indices.

    Parameters
    ----------
    eta_rho, xi_rho
        Positional indexers forwarded to ``mask_rho.isel``.

    Returns
    -------
    The selected values (``.values`` of the selection), or ``nan`` when the
    selection raises ``IndexError``.
    """
    try:
        return grid_ds.mask_rho.isel(eta_rho=eta_rho, xi_rho=xi_rho).values
    except IndexError:
        # ``np.nan`` is the canonical constant; the ``np.NAN`` alias is not
        # available in NumPy 2.x, which would turn this fallback into an
        # AttributeError.
        return np.nan


def print_values(west_idx, north_idx, east_idx, south_idx):
    """
    Print ``mask_rho`` on the outermost row/column of each side.

    Parameters
    ----------
    west_idx, east_idx
        Arrays of ``eta_rho`` positions (squeezed before use) to read on the
        first and last ``xi_rho`` column respectively.
    north_idx, south_idx
        Arrays of ``xi_rho`` positions (squeezed before use) to read on the
        last and first ``eta_rho`` row respectively.
    """
    # ``sizes`` is the name -> length mapping; using ``Dataset.dims`` as a
    # mapping is deprecated in xarray.
    last_eta = grid_ds.sizes['eta_rho'] - 1
    last_xi = grid_ds.sizes['xi_rho'] - 1

    print(f"West values: {get_values(eta_rho=west_idx.squeeze(), xi_rho=0)}")
    print(f"North values: {get_values(eta_rho=last_eta, xi_rho=north_idx.squeeze())}")
    print(f"East values: {get_values(eta_rho=east_idx.squeeze(), xi_rho=last_xi)}")
    print(f"South values: {get_values(eta_rho=0, xi_rho=south_idx.squeeze())}")


In [None]:
# Optional step: find the flagged border points on all four sides, print their
# mask_rho values, set them to 0, and print the values again.
if fix_boundaries:
    # Each *_idx is an (n_matches, 1) array of positions along that border.
    west_idx = find_matches('west')
    north_idx = find_matches('north')
    east_idx = find_matches('east')
    south_idx = find_matches('south')
    
    # Values before the fix.
    print_values(west_idx, north_idx, east_idx, south_idx)
    
    # Get xi and eta coordinates to use loc assignment (isel or iloc won't work)
    # West: the first xi_rho label (scalar) paired with the eta_rho labels of the matches.
    # Multiplying by ones_like(...) leaves the eta values unchanged here.
    xi_rho_west_coords = grid_ds.coords['xi_rho'][0]
    eta_rho_west_coords = grid_ds.coords['eta_rho'][west_idx.squeeze()] * xr.ones_like(xi_rho_west_coords)
    print(f"Xi west coordinates: {xi_rho_west_coords}\n",
          f"Eta west coordinates: {eta_rho_west_coords}\n")
    
    # North: the xi_rho labels of the matches; the last eta_rho label is
    # repeated once per matched xi via ones_like(...).
    xi_rho_north_coords = grid_ds.coords['xi_rho'][north_idx.squeeze()]
    eta_rho_north_coords = grid_ds.coords['eta_rho'][-1] * xr.ones_like(xi_rho_north_coords)
    print(f"Xi north coordinates: {xi_rho_north_coords}\n",
          f"Eta north coordinates: {eta_rho_north_coords}\n")
    
    # East: the last xi_rho label (scalar) paired with the eta_rho labels of the matches.
    xi_rho_east_coords = grid_ds.coords['xi_rho'][-1]
    eta_rho_east_coords = grid_ds.coords['eta_rho'][east_idx.squeeze()] * xr.ones_like(xi_rho_east_coords)
    print(f"Xi east coordinates: {xi_rho_east_coords}\n",
          f"Eta east coordinates: {eta_rho_east_coords}\n")
    
    # South: the xi_rho labels of the matches; the first eta_rho label is
    # repeated once per matched xi via ones_like(...).
    xi_rho_south_coords = grid_ds.coords['xi_rho'][south_idx.squeeze()]
    eta_rho_south_coords = grid_ds.coords['eta_rho'][0] * xr.ones_like(xi_rho_south_coords)
    print(f"Xi south coordinates: {xi_rho_south_coords}\n",
          f"Eta south coordinates: {eta_rho_south_coords}\n")

    # mask the corresponding points in mask_rho
    # One (xi labels, eta labels) pair per side, in the order west, north, east, south.
    coords = (
        (xi_rho_west_coords, eta_rho_west_coords),
        (xi_rho_north_coords, eta_rho_north_coords),
        (xi_rho_east_coords, eta_rho_east_coords),
        (xi_rho_south_coords, eta_rho_south_coords),
    )

    for coord in coords:
        # coord[0] holds the xi_rho labels, coord[1] the eta_rho labels; the
        # selected points are overwritten with 0 in place.
        grid_ds.mask_rho.loc[dict(xi_rho=coord[0], eta_rho=coord[1])] = 0
    
    # Values after the fix, at the same positions as the first printout.
    print_values(west_idx, north_idx, east_idx, south_idx)

## Mask out the specific points that cause instabilities

In [None]:
# Optional step: set a fixed rectangle of mask_rho to 0 and plot it before and after.
if mask_bad_area:
    # Rectangle to mask, as positional indices: eta_rho 685..697, xi_rho 1434..end.
    # Defined once so the two plots and the assignment cannot drift apart.
    bad_area = dict(eta_rho=slice(685, 698), xi_rho=slice(1434, None))

    # State before masking.
    grid_ds.mask_rho.isel(**bad_area).plot(figsize=(14, 7))

    # Assign by dimension name through xarray instead of writing into
    # ``.values`` positionally: this does not depend on the axis order of
    # mask_rho or on ``.values`` handing back a writable view of the stored data.
    grid_ds.mask_rho[bad_area] = 0

    # State after masking.
    grid_ds.mask_rho.isel(**bad_area).plot(figsize=(14, 7))


## Rewrite other masks

Andre matlab code:

```matlab
% mask at u, v and psi points
mask_u = mask_rho(:,1:end-1).*mask_rho(:,2:end);
mask_v = mask_rho(1:end-1,:).*mask_rho(2:end,:);
mask_psi = ...
    mask_rho(1:end-1,1:end-1).*mask_rho(1:end-1,2:end).*...
    mask_rho(2:end,1:end-1).*mask_rho(2:end,2:end);
``` 

In [None]:
# mask_u is the product of every pair of mask_rho values adjacent along xi_rho,
# so an entry is nonzero only where both neighbours are nonzero.
left_mask = grid_ds['mask_rho'].isel(xi_rho=slice(1, None)).values  # every column except the first
right_mask = grid_ds['mask_rho'].isel(xi_rho=slice(None, -1)).values  # every column except the last
u_mask = np.multiply(left_mask, right_mask)
grid_ds['mask_u'].values = u_mask

In [None]:
# mask_v is the product of every pair of mask_rho values adjacent along eta_rho,
# so an entry is nonzero only where both neighbours are nonzero.
bottom_mask = grid_ds['mask_rho'].isel(eta_rho=slice(1, None)).values  # every row except the first
upper_mask = grid_ds['mask_rho'].isel(eta_rho=slice(None, -1)).values  # every row except the last
v_mask = np.multiply(bottom_mask, upper_mask)
grid_ds['mask_v'].values = v_mask

In [None]:
# mask_psi is the product of the four mask_rho values in every 2x2 block of
# neighbouring points, so an entry is nonzero only where all four are nonzero.
rho_mask = grid_ds['mask_rho'].values
corner_00 = rho_mask[:-1, :-1]  # first-axis index i,     second-axis index j
corner_01 = rho_mask[:-1, 1:]   # first-axis index i,     second-axis index j + 1
corner_10 = rho_mask[1:, :-1]   # first-axis index i + 1, second-axis index j
corner_11 = rho_mask[1:, 1:]    # first-axis index i + 1, second-axis index j + 1
psi_mask = corner_00 * corner_01 * corner_10 * corner_11
grid_ds['mask_psi'].values = psi_mask

In [None]:
# Shapes of the four masks on one line, in the order rho, u, v, psi.
print(*(mask.shape for mask in (rho_mask, u_mask, v_mask, psi_mask)))

## Save

In [None]:
# Write the edited grid to a new NETCDF4 file; the path differs from the input file, which is not overwritten.
grid_ds.to_netcdf(path='fram_data/norfjords_160m_grid_version3.nc', format='NETCDF4')

In [None]:
# Show mask_rho for positions eta_rho 953..957 and xi_rho 1435..1439.
# NOTE(review): this window differs from the rectangle zeroed in the "Mask out"
# section (eta_rho 685..697) - confirm that this is the region meant to be inspected.
grid_ds['mask_rho'].isel({'eta_rho': slice(953, 958), 'xi_rho': slice(1435, 1440)})

In [None]:
# Show mask_u for positions eta_u 953..957 and xi_u 1435..1439.
grid_ds['mask_u'].isel({'eta_u': slice(953, 958), 'xi_u': slice(1435, 1440)})

In [None]:
# Show mask_v for positions eta_v 953..957 and xi_v 1435..1439.
grid_ds['mask_v'].isel({'eta_v': slice(953, 958), 'xi_v': slice(1435, 1440)})

In [None]:
# Show mask_psi for positions eta_psi 953..957 and xi_psi 1435..1439.
grid_ds['mask_psi'].isel({'eta_psi': slice(953, 958), 'xi_psi': slice(1435, 1440)})