In [1]:
import time
from collections import defaultdict
import numpy as np
from pynextsim import NextsimBin

In [2]:
class TriNeighbors:
    neighbors = None
    def __init__(self, t):
        elem2edge, edge2elem = self.get_edge_elem_relationship(t)
        # keep nearest neighbors for each element
        self.neighbors = []
        for i in range(t.shape[0]):
            neighbors_lists = [edge2elem[edge] for edge in elem2edge[i] if edge in edge2elem]
            neighbors_i = []
            for n1 in neighbors_lists:
                if len(n1) != 1:
                    for n2 in n1:
                        if n2 != i:
                            neighbors_i.append(n2)
            self.neighbors.append(neighbors_i)
        self.nneighbors = [len(n) for n in self.neighbors]
        
    def get_edge_elem_relationship(self, t):
        """ Create to maps: element to edge, and edge to element in one pass"""
        elem2edge = []
        edge2elem = defaultdict(list)
        for i, elem in enumerate(t):
            jj = [(elem[0], elem[1]), (elem[1], elem[2]), (elem[2], elem[0])]
            edges = [tuple(sorted(j)) for j in jj]
            elem2edge.append(edges)
            for edge in edges:
                edge2elem[edge].append(i)
        return elem2edge, edge2elem
    
    def get_neighbors(self, i, n=1, e=()):
        """ Get neighbors of element <i> crossing <n> edges

        Parameters
        ----------
        i : int, index of element
        n : int, number of edges to cross
        e : (int,), indeces to exclude

        Returns
        -------
        l : list, List of unique inidices of existing neighbor elements

        """
        # return list of existing immediate neigbours
        if n == 1:
            return self.neighbors[i]
        # recursively return list of neighbors after 1 edge crossing
        n2 = []
        for j in self.neighbors[i]:
            if j not in e:
                n2.extend(self.get_neighbors(j, n-1, e+(i,)))
        return list(set(self.neighbors[i] + n2))

    def get_neighbors_many(self, indices, n=1):
        """ Group neighbours of several elements """
        neighbors_many = [self.get_neighbors(i, n=n) for i in indices]
        return np.unique(np.hstack(neighbors_many)).astype(int)
    
    def get_distance_to_border(self):
        dist = np.zeros(len(self.neighbors)) + np.nan
        border = np.where(np.array(self.nneighbors) < 3)[0]
        dist[border] = 0

        d = 1
        while np.any(np.isnan(dist)):
            for i in np.where(dist == d - 1)[0]:
                neibs = self.get_neighbors(i)
                for j in neibs:
                    if np.isnan(dist[j]):
                        dist[j] = d
            d += 1
        return dist

In [3]:
t = np.load('field_20230101T000000Z.npz')['t']
t.shape

(138086, 3)

In [4]:
tn = TriNeighbors(t)
%timeit TriNeighbors(t)


774 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# only nearest neighbors (distance = 1 edge)
tn.get_neighbors(100, 1)

[102, 25758, 25764]

In [6]:
# all neighbors at distance <= 3 edges
tn.get_neighbors(100, 3)

[25792,
 25761,
 25791,
 135075,
 25764,
 100,
 102,
 135076,
 25767,
 25760,
 101,
 133805,
 135151,
 133848,
 25756,
 25757,
 25758,
 25759]

In [7]:
def get_neighbors(i, n=1):
    """ Get <n> neighbours for elemenets with indices in <i> """
    nn = []
    for j in i:
        nn.append(tn.get_neighbors(j))
    return nn

In [8]:
%timeit get_neighbors([10])

342 ns ± 1.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [9]:
%timeit get_neighbors(range(10))

2.17 µs ± 12 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
%timeit get_neighbors(range(100))

19.5 µs ± 42.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [11]:
%timeit get_neighbors(range(100000))

19.6 ms ± 18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# get all neighbors at a distance <= 3 edges
%timeit get_neighbors(range(100000), 3)

19.6 ms ± 3.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
