In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.tri import Triangulation
from pynextsim import NextsimBin
from scipy.ndimage import maximum_filter
from scipy.spatial import cKDTree
from tqdm.notebook import tqdm

from lmsiage.utils import compute_mapping, get_area_ratio, get_same_elements
%matplotlib widget


In [None]:
def nextsimbin2tri(restart_file, maskfile='mask.npy', min_x=-2.5e6, max_y=2.1e6, res=20000):
    """Read a neXtSIM restart file and return its mesh restricted to a raster mask.

    The mask is a 2D raster whose upper-left corner is at ``(min_x, max_y)``
    and whose cells are ``res`` wide (same units as the mesh coordinates);
    row indices grow as y decreases. An element is kept when the mask cell
    containing its centroid is truthy. Centroids outside the raster are
    clamped onto its border cells.

    Parameters
    ----------
    restart_file : str
        Path to a neXtSIM binary restart file (read with ``NextsimBin``).
    maskfile : str
        Path to a ``.npy`` file holding the 2D mask raster.
    min_x : float
        x coordinate of the left edge of the mask raster.
    max_y : float
        y coordinate of the top edge of the mask raster.
    res : float
        Size of one mask cell.

    Returns
    -------
    tri : Triangulation
        Mesh containing only the kept elements and the nodes they use,
        renumbered contiguously from 0.
    ids : ndarray
        Node ids (``mesh_info.get_var('id')``) of the retained nodes, aligned
        with ``tri.x`` / ``tri.y``.
    """
    # read raw data
    n = NextsimBin(restart_file)
    tri = Triangulation(n.mesh_info.nodes_x, n.mesh_info.nodes_y, n.mesh_info.indices)
    ids = n.mesh_info.get_var('id')

    # Cast to bool: a 0/1 int/float raster would otherwise be interpreted as
    # integer indices (or rejected) when selecting triangles below.
    mask = np.load(maskfile).astype(bool)

    # Mask cell containing each element centroid.
    x_el = tri.x[tri.triangles].mean(axis=1)
    y_el = tri.y[tri.triangles].mean(axis=1)
    cols_el = np.clip((x_el - min_x) / res, 0, mask.shape[1] - 1).astype(int)
    rows_el = np.clip((max_y - y_el) / res, 0, mask.shape[0] - 1).astype(int)
    el_mask = mask[rows_el, cols_el]

    # Keep only nodes used by retained elements and renumber them 0..k-1.
    sub_tri = tri.triangles[el_mask]
    uniq_nodes, uniq_inv = np.unique(sub_tri, return_inverse=True)
    # np.unique returns a flat inverse in NumPy 1.x and an input-shaped one in
    # NumPy 2.x; Triangulation needs (ntri, 3), so reshape explicitly.
    new_triangles = uniq_inv.reshape(sub_tri.shape)

    sub_mesh = Triangulation(tri.x[uniq_nodes], tri.y[uniq_nodes], new_triangles)
    return sub_mesh, ids[uniq_nodes]

def get_tri_a_from_nextsim(tri_0, ids_0, tri_o, ids_o, N=10):
    """Build the advected mesh: connectivity of ``tri_0`` at positions from ``tri_o``.

    Nodes of ``tri_0`` whose id also appears in ``ids_o`` take the coordinates
    of that node in ``tri_o``. Nodes whose id is absent from ``ids_o``
    ("lost" nodes) are moved by the same displacement as their nearest node in
    ``tri_0`` that does survive in ``tri_o``. The neighbour search starts with
    the ``N`` closest nodes and widens to ``2N``, ``5N`` and ``10N`` nodes when
    none of them survives.

    Parameters
    ----------
    tri_0 : Triangulation
        Mesh at the earlier time.
    ids_0 : array-like
        Node ids of ``tri_0`` aligned with ``tri_0.x``; assumed unique.
    tri_o : Triangulation
        Mesh at the later time.
    ids_o : array-like
        Node ids of ``tri_o`` aligned with ``tri_o.x``; assumed unique.
    N : int
        Initial number of nearest neighbours searched for each lost node.

    Returns
    -------
    Triangulation
        Mesh with ``tri_0.triangles`` and the new node coordinates.

    Raises
    ------
    ValueError
        If a lost node has no surviving node among its ``10 * N`` nearest
        neighbours (or among all nodes, when there are fewer).
    """
    ids_0 = np.asarray(ids_0)
    ids_o = np.asarray(ids_o)
    new_x = np.full_like(tri_0.x, np.nan, dtype=float)
    new_y = np.full_like(tri_0.y, np.nan, dtype=float)

    # Surviving nodes: copy their position in tri_o.
    _, idx0, idx1 = np.intersect1d(ids_0, ids_o, return_indices=True)
    new_x[idx0] = tri_o.x[idx1]
    new_y[idx0] = tri_o.y[idx1]

    lost_indices = np.flatnonzero(~np.isin(ids_0, ids_o))
    if lost_indices.size and ids_o.size == 0:
        raise ValueError("tri_o has no nodes, so lost nodes of tri_0 cannot be placed")

    # Sorted ids of tri_o: membership of a neighbour id (and its index in
    # tri_o) is then a binary search instead of a full intersection per node.
    order_o = np.argsort(ids_o)
    sorted_ids_o = ids_o[order_o]

    points = np.column_stack((tri_0.x, tri_0.y))
    tree = cKDTree(points)
    n_points = len(points)

    for lost_index in lost_indices:
        for factor in (1, 2, 5, 10):
            # Never request more neighbours than points exist: cKDTree pads
            # missing neighbours with index n_points, which is out of range.
            k = min(N * factor, n_points)
            distances, neib_indices = tree.query(points[lost_index], k=k)
            distances = np.atleast_1d(distances)       # k == 1 returns scalars
            neib_indices = np.atleast_1d(neib_indices)
            neib_ids = ids_0[neib_indices]
            pos = np.minimum(np.searchsorted(sorted_ids_o, neib_ids), sorted_ids_o.size - 1)
            surviving = np.flatnonzero(sorted_ids_o[pos] == neib_ids)
            if surviving.size:
                break
        else:
            raise ValueError(
                f"node id {ids_0[lost_index]!r} has no surviving neighbour "
                f"among its {k} nearest nodes"
            )

        # Closest surviving neighbour: its index in tri_0 and in tri_o.
        nearest = surviving[np.argmin(distances[surviving])]
        nn_index0 = neib_indices[nearest]
        nn_index1 = order_o[pos[nearest]]

        # Keep the lost node's offset from that neighbour.
        new_x[lost_index] = tri_o.x[nn_index1] + (tri_0.x[lost_index] - tri_0.x[nn_index0])
        new_y[lost_index] = tri_o.y[nn_index1] + (tri_0.y[lost_index] - tri_0.y[nn_index0])

    return Triangulation(new_x, new_y, triangles=tri_0.triangles)

def compute_mapping_fast(tri_a, tri_o):
    """Map elements of the advected mesh ``tri_a`` onto elements of ``tri_o``.

    Element pairs that ``get_same_elements`` reports as identical are mapped
    one-to-one. Every remaining ("new") element of ``tri_o`` is mapped from
    the ``tri_a`` element whose centroid is nearest to its own centroid.
    All weights are 1.

    Parameters
    ----------
    tri_a, tri_o : Triangulation
        Source (advected) and destination meshes.

    Returns
    -------
    src2dst : ndarray of int, shape (M, 2)
        Rows of ``(tri_a element index, tri_o element index)``; shared
        elements first, then new elements of ``tri_o``.
    weights : ndarray of int, shape (M,)
        All ones.
    """
    xa, ya, ta = tri_a.x, tri_a.y, tri_a.triangles
    xo, yo, to = tri_o.x, tri_o.y, tri_o.triangles

    # NOTE(review): only the x coordinates of the element vertices are passed
    # to get_same_elements — confirm that this is intended.
    same_in_a, same_in_o = get_same_elements(xa[ta], xo[to])

    # tri_o elements not matched to any tri_a element.
    is_new = np.ones(to.shape[0], dtype=bool)
    is_new[same_in_o] = False
    new_elem_idx = np.flatnonzero(is_new)

    # Nearest tri_a element (by centroid) for every new tri_o element.
    tree_a = cKDTree(np.column_stack([xa[ta].mean(axis=1), ya[ta].mean(axis=1)]))
    new_centroids = np.column_stack([
        xo[to[new_elem_idx]].mean(axis=1),
        yo[to[new_elem_idx]].mean(axis=1),
    ])
    _, nearest_a = tree_a.query(new_centroids, k=1)

    # Build the arrays directly: keeps the (M, 2) int shape even when M == 0.
    src2dst = np.concatenate([
        np.column_stack([same_in_a, same_in_o]),
        np.column_stack([nearest_a, new_elem_idx]),
    ]).astype(int)
    weights = np.ones(len(src2dst), dtype=int)
    return src2dst, weights


In [None]:
# Input restarts live under res_dir; derived meshes/mappings go under sia_dir.
res_dir = 'restarts'
sia_dir = 'siage'

# Daily sequence of `duration` dates beginning at date_start.
duration = 365
date_start = pd.Timestamp("1991-01-02")
dates = [date_start + i * pd.Timedelta(days=1) for i in range(duration)]

# One restart file per date: <res_dir>/<YYYYMMDD>/inputs/field_<YYYYMMDD>T000000Z.bin
restart_files = [
    f"{res_dir}/{date.strftime('%Y%m%d')}/inputs/field_{date.strftime('%Y%m%dT000000Z')}.bin"
    for date in dates
]

In [None]:
# Walk through consecutive daily restarts. For each day the previous mesh
# (tri_0) is advected to the current one (tri_o), and the mesh, element
# mapping and area ratio are cached in <sia_dir>/mesh/mesh_YYYYMMDD.npz.
mesh_dir = f"{sia_dir}/mesh"
os.makedirs(mesh_dir, exist_ok=True)  # np.savez does not create missing directories

tri_0, ids_0 = nextsimbin2tri(restart_files[0])

for date, restart_file in tqdm(zip(dates[1:], restart_files[1:]), total=len(dates) - 1):
    mesh_dst_file = f"{mesh_dir}/mesh_{date.strftime('%Y%m%d')}.npz"
    if os.path.exists(mesh_dst_file):
        # Cached: only the mesh and node ids are needed to continue the chain.
        # allow_pickle=True: these files are written by this loop (below).
        with np.load(mesh_dst_file, allow_pickle=True) as f:
            tri_o = Triangulation(f['x'], f['y'], f['t'])
            ids_o = f['ids']
    else:
        # tqdm.write keeps the message from breaking the progress bar.
        tqdm.write(f"{restart_file} {mesh_dst_file}")
        tri_o, ids_o = nextsimbin2tri(restart_file)
        tri_a = get_tri_a_from_nextsim(tri_0, ids_0, tri_o, ids_o, N=10)
        # compute_mapping_fast(tri_a, tri_o), defined in this notebook, can be
        # swapped in here as an alternative to lmsiage's compute_mapping.
        src2dst, weights = compute_mapping(tri_a, tri_o, 15000, cores=6)

        area_ratio = get_area_ratio(tri_0, tri_a, tri_o, src2dst, weights)
        np.savez(mesh_dst_file, x=tri_o.x, y=tri_o.y, t=tri_o.triangles,
                 src2dst=src2dst, weights=weights, ar=area_ratio, ids=ids_o)

    # Today's mesh becomes the source mesh for the next day.
    tri_0, ids_0 = tri_o, ids_o


In [None]:
# Reload the first-day mesh: the loop above leaves tri_0/ids_0 at the last day.
tri_0, ids_0 = nextsimbin2tri(restart_file=restart_files[0])


In [None]:
# Draw the masked mesh on its own figure; plt.triplot alone would draw onto
# whichever figure happens to be current (stale state with %matplotlib widget).
fig, ax = plt.subplots()
ax.triplot(tri_0)
ax.set_aspect('equal')
ax.set_title('Masked mesh (tri_0)')
plt.show()

In [None]:
# Rasterise the nodes of tri_0 onto a grid of `res`-sized cells whose
# upper-left corner is (min_x, max_y); row index grows as y decreases.
# The result `a` is 1 where at least one node falls in a cell, 0 elsewhere.
min_x = -2.5e6
max_x = 2.5e6
min_y = -2.9e6
max_y = 2.1e6
res = 20000

cols = (tri_0.x - min_x) / res
cols = cols.clip(0, cols.max()).astype(int)
rows = (max_y - tri_0.y) / res
rows = rows.clip(0, rows.max()).astype(int)

# NOTE(review): the grid is sized from the node extent; max_x and min_y are
# not used here.
a = np.zeros(shape=(rows.max() + 1, cols.max() + 1))
# Occupancy map: repeated (row, col) pairs just set the same cell to 1.
a[rows, cols] = 1
# To write the mask image: plt.imsave('mask.png', a, cmap='gray')

In [None]:
# plt.imread (via Pillow) does not expand '~', so expand it explicitly.
# Keeps only the first channel; assumes an RGB(A) image such as the one
# plt.imsave(..., cmap='gray') writes.
# TODO: move mask.png next to the notebook instead of a personal Downloads path.
mask = plt.imread(os.path.expanduser('~/Downloads/mask.png'))[:, :, 0]

In [None]:
# Grow the mask by one cell in every direction (3x3 maximum filter) and make
# it boolean. NOTE: not idempotent — re-running this cell dilates the mask
# again; re-run the cell that reads mask.png first.
mask = maximum_filter(mask, size=3).astype(bool)

In [None]:
# Re-read the first-day mesh; presumably to pick up an updated mask.npy
# (nextsimbin2tri reads the mask from disk) — confirm.
tri_0, ids_0 = nextsimbin2tri(restart_file=restart_files[0])

In [None]:
# Inspect one saved mesh/mapping file. The path is relative to the notebook
# and does not follow sia_dir. allow_pickle=True permits object arrays —
# only use it on files you produced yourself.
with np.load('../2020/mesh_20200102.npz', allow_pickle=True) as ds:
    x, y, t = ds['x'], ds['y'], ds['t']
    ids, ar = ds['ids'], ds['ar']
    src2dst, weights = ds['src2dst'], ds['weights']
tri_o = Triangulation(x, y, t)

In [None]:
# Area ratio ('ar') from the loaded file, on a diverging colormap.
# vmin/vmax are tripcolor's documented colour-limit parameters; a `clim=`
# keyword is applied before tripcolor sets its norm and is discarded in
# some Matplotlib versions.
fig, ax = plt.subplots()
pc = ax.tripcolor(tri_o, ar, cmap='bwr', vmin=0.9, vmax=1.1)
fig.colorbar(pc, ax=ax, label='area ratio')
ax.set_aspect('equal')
ax.set_title('Area ratio, mesh_20200102')
plt.show()