# temporal network

> utilities holding temporal network info

In [None]:
#| default_exp tnet

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import numpy as np
from numba import prange
from moraine.utils_ import ngpjit, ngjit

In [None]:
#| export
def _imagepair_from_bandwidth(nimages,bandwidth=None):
    if bandwidth is None: bandwidth = nimages
    assert nimages >= bandwidth
    ref, sec = np.triu_indices(nimages, 1)
    idx = np.where((sec-ref)<=bandwidth)
    return np.stack((ref[idx], sec[idx]),axis=-1).astype(np.int32)

In [None]:
#| export
@ngpjit
def are_edges_sorted(edges):
    '''sorted = rows are in non-decreasing order of (column 0, column 1)'''
    num_neighbours = edges.shape[0] - 1
    for k in prange(num_neighbours):
        ref_here = edges[k, 0]
        ref_next = edges[k + 1, 0]
        if ref_here > ref_next:
            # first column decreases between neighbouring rows
            return False
        if ref_here == ref_next and edges[k, 1] > edges[k + 1, 1]:
            # same first column, but the second column decreases
            return False
    return True

In [None]:
#| export
@ngpjit
def are_edges_directed(edges):
    '''directed = in every row the reference index (column 0) is strictly smaller than the secondary index (column 1)'''
    for k in prange(edges.shape[0]):
        ref = edges[k, 0]
        sec = edges[k, 1]
        if sec <= ref:
            return False
    return True

In [None]:
#| export
@ngjit
def are_edges_connected(edges):
    '''Check that the first row of every run of equal reference index links it to the next index.

    Rows are scanned in order; whenever the reference index (column 0) differs
    from the one of the previous run, that row has to be `(ref, ref + 1)`.
    '''
    last_ref = -1
    for k in range(edges.shape[0]):
        ref = edges[k,0]
        if ref != last_ref:
            # first row of a new run: it must pair ref with ref + 1
            if ref != edges[k,1]-1:
                return False
            last_ref = ref
    return True

In [None]:
#| export
class TempNet(object):
    '''Temporal network stored as an int32 array of (reference, secondary) image index pairs.'''
    def __init__(self,image_pairs, # array-like of shape (n_pairs, 2), each row a (reference, secondary) index pair
                 check_if_valid=True, # whether to check the pairs are well shaped, sorted, directed and connected
                ):
        '''Store `image_pairs` as an int32 array, optionally validating them first.

        Raises `ValueError` if `check_if_valid` is true and the pairs fail one of the checks.
        '''
        # Convert up front so nested lists are accepted and the checks see exactly what is stored.
        image_pairs = np.asarray(image_pairs).astype(np.int32)
        if check_if_valid:
            if image_pairs.ndim != 2 or image_pairs.shape[1] != 2:
                raise ValueError('input image pairs should be a 2D array with shape (n_pairs, 2).')
            # The validators have to be *called*; a bare function object is always truthy.
            if not are_edges_sorted(image_pairs):
                raise ValueError('input image pairs are not sorted.')
            if not are_edges_directed(image_pairs):
                raise ValueError('input image pairs are not directed (reference index larger than or equal to secondary index).')
            if not are_edges_connected(image_pairs):
                raise ValueError('input image pairs are not connected.')
        self.image_pairs = image_pairs

    @classmethod
    def from_bandwidth(cls, nimages, # number of images
                       bandwidth=None, # maximum index difference between paired images, None to pair all images
                      ):
        '''classmethod to create a TempNet from the number of images and the temporal bandwidth.'''
        image_pairs = _imagepair_from_bandwidth(nimages,bandwidth)
        # pairs generated this way need no validation
        return cls(image_pairs,check_if_valid=False)

    def save(self,path:str, # zarr path
            ):
        '''Save the TempNet.'''
        # local import: the top-level imports of this file do not bring zarr into scope
        import zarr
        tempnet_zarr = zarr.open(path,mode='w',shape=self.image_pairs.shape,dtype=self.image_pairs.dtype)
        tempnet_zarr[:] = self.image_pairs[:]

    @classmethod
    def load(cls, zarr_path:str, # zarr path
            ):
        '''classmethod to load a saved TempNet.'''
        # local import: the top-level imports of this file do not bring zarr into scope
        import zarr
        tempnet_zarr = zarr.open(zarr_path,mode='r')
        return cls(tempnet_zarr[:], check_if_valid=False)

Usage:

Create a temporal network by specifying the number of images and the temporal bandwidth (the maximum number of subsequent images each image is paired with):

In [None]:
# 5 images, bandwidth 2: each image is paired with at most the next 2 images
tnet = TempNet.from_bandwidth(5,2)

In [None]:
tnet.image_pairs

array([[0, 1],
       [0, 2],
       [1, 2],
       [1, 3],
       [2, 3],
       [2, 4],
       [3, 4]], dtype=int32)

Or, you can specify your own image pairs:

In [None]:
# each row of the stacked array is one (reference, secondary) pair; here image i is paired with image i+1
tnet = TempNet(np.stack(([0,1,2,3],[1,2,3,4]),axis=-1))
tnet.image_pairs

array([[0, 1],
       [1, 2],
       [2, 3],
       [3, 4]], dtype=int32)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()