In [None]:
%load_ext cython

In [1]:
import pandas as pd 
import torch 
import numpy as np
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import OptTensor
from torch_scatter import scatter
import multiprocessing as mp

# Load the saved sample batch from disk; every later cell reads from `graph`.
# NOTE(review): torch.load unpickles arbitrary Python objects — only point this
# at trusted files.
graph = torch.load('sample_data/batch.pth')

HeteroGraphData(
  readout='edge',
  node={
    x=[150319, 2],
    n_id=[150319],
  },
  (node, to, node)={
    edge_index=[2, 706829],
    edge_attr=[706829, 5],
    y=[706829],
    timestamps=[706829],
    e_id=[706829],
    input_id=[8192],
    edge_label_index=[2, 8192],
    edge_label=[8192],
  },
  (node, rev_to, node)={
    edge_index=[2, 891736],
    edge_attr=[891736, 5],
    e_id=[891736],
  }
)

In [6]:
graph['node', 'rev_to', 'node'].num_edges

891736

In [None]:
# NOTE(review): `data` is not defined in any cell visible in this notebook —
# this cell relies on leftover kernel state; confirm where `data` comes from.
# np.lexsort treats the LAST key as the primary one, so `res` is the row order
# sorted by column 1 with ties broken by column 2.
res = np.lexsort((data[:, 2], data[:, 1]))

In [None]:
res[:100]

In [None]:
res.shape

In [None]:
%%cython

import cython

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def assign_ports(int[:,:] arr, int[:] ports):
    """Number the distinct ``u`` values inside each run of equal ``v`` values.

    ``arr`` holds one ``(u, v)`` pair per row and is expected to be ordered so
    that rows sharing the same ``v`` are adjacent. Within each such run the
    first distinct ``u`` gets port 0, the next new ``u`` gets port 1, and so
    on; a repeated ``u`` re-uses the port it was given earlier in the run.

    Parameters
    ----------
    arr : int[:, :]
        Two-column buffer of ``(u, v)`` pairs.
    ports : int[:]
        Output buffer with at least ``arr.shape[0]`` entries; filled in place.

    Returns
    -------
    The ``ports`` buffer that was passed in.

    Raises
    ------
    ValueError
        If ``arr`` does not have exactly two columns or ``ports`` is too short.
    """
    cdef:
        Py_ssize_t j
        Py_ssize_t n_rows = arr.shape[0]
        int u, v
        int prev_v = -1
        int counter = 0
        dict mapping = {}

    # Bounds checking is disabled for this function, so validate the shapes
    # up front instead of risking out-of-bounds reads/writes.
    if arr.shape[1] != 2:
        raise ValueError("arr must have exactly two columns (u, v)")
    if ports.shape[0] < n_rows:
        raise ValueError("ports must have at least as many entries as arr has rows")

    # Typed index loop: iterating the memoryview directly would box every row
    # and element as Python objects.
    for j in range(n_rows):
        u = arr[j, 0]
        v = arr[j, 1]
        if j == 0 or v != prev_v:
            # New run of v: restart the numbering.
            counter = 0
            mapping = {u: counter}
            ports[j] = counter
        elif u in mapping:
            ports[j] = mapping[u]
        else:
            counter += 1
            mapping[u] = counter
            ports[j] = counter

        prev_v = v

    return ports

In [None]:
import time


def process_ports(df, direction=('source', 'target')):
    """Compute one port number per edge of ``df`` for the given direction.

    Rows are sorted by ``direction[1]`` and then by ``'t'``; ``assign_ports``
    then numbers the distinct ``direction[0]`` values inside each group of
    equal ``direction[1]``. The result is returned in the original row order
    of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the two columns named in ``direction`` plus a ``'t'`` column.
    direction : sequence of two column names
        ``(neighbour_column, grouping_column)``.

    Returns
    -------
    torch.Tensor
        int32 tensor with one entry per row of ``df``.
    """
    cols = list(direction)  # DataFrame column selection needs a list, not a tuple
    df_sorted = df.sort_values([cols[1], 't'])
    # Size the buffer from the frame itself rather than from the global graph.
    ports = np.zeros((len(df_sorted), ), dtype=np.int32)
    array = df_sorted[cols].to_numpy(dtype=np.int32)
    assign_ports(array, ports)  # fills `ports` in place

    # df_sorted.index holds each sorted row's original position; argsort of it
    # maps the ports back to the original edge order.
    return torch.tensor(ports)[np.argsort(df_sorted.index)]


for i in range(10):

    t1 = time.perf_counter()

    edge_index = graph['node', 'to', 'node'].edge_index
    df = pd.DataFrame(
        torch.cat([edge_index.T, graph.timestamps.reshape((-1, 1))], dim=1).numpy().astype('int'),
        columns=['source', 'target', 't'],
    )

    # Both directions are independent, so they are computed in two worker processes.
    with mp.Pool(2) as pool:
        ports_1, ports_2 = pool.starmap(process_ports, [
            (df, ['source', 'target']),
            (df, ['target', 'source'])
        ])

    # Serial alternative:
    # ports_1 = process_ports(df, ['source', 'target'])
    # ports_2 = process_ports(df, ['target', 'source'])

    t2 = time.perf_counter()
    print(f"Retrieved data in {t2-t1:.2f}s")

In [None]:
# NOTE(review): in the cells above `ports` is only assigned inside
# process_ports (a local), so on a fresh kernel this name is not defined at
# this point — confirm what this cell is meant to display.
ports

In [None]:
# NOTE(review): `array` is only assigned inside process_ports (a local) in the
# cells above, so this cell depends on leftover kernel state — confirm.
array.dtype

In [None]:
# Inputs for the add_ports implementations, taken from the loaded batch.
num_nodes = graph.num_nodes
edge_index = graph['node', 'to', 'node'].edge_index
timestamp = graph.timestamps

# One int32 row per edge: the two edge_index entries followed by the timestamp.
edges = (
    torch.cat((edge_index.t(), timestamp.view(-1, 1)), dim=1)
    .numpy()
    .astype(np.int32)
)

# From here on `edge_index` is an int32 NumPy array, no longer a tensor.
edge_index = edge_index.numpy().astype(np.int32)

In [None]:
def to_adj_nodes_with_times(num_nodes: int, edges: np.ndarray) -> tuple:
    """Build timestamped incoming and outgoing adjacency lists.

    Parameters
    ----------
    num_nodes : int
        Every node id in ``range(num_nodes)`` gets an entry (possibly empty)
        in both dictionaries.
    edges : np.ndarray
        Iterable of ``(u, v, t)`` rows: an edge from ``u`` to ``v`` at time ``t``.

    Returns
    -------
    (adj_list_in, adj_list_out) : tuple of dict
        ``adj_list_in[v]`` lists ``(u, t)`` for every edge ending in ``v``;
        ``adj_list_out[u]`` lists ``(v, t)`` for every edge leaving ``u``.
        Entries keep the order of ``edges`` and hold plain Python ints.

    Raises
    ------
    KeyError
        If an edge references a node id outside ``range(num_nodes)``.
    """
    adj_list_out = {node: [] for node in range(num_nodes)}
    adj_list_in = {node: [] for node in range(num_nodes)}
    for u, v, t in edges:
        # Normalise NumPy scalars to Python ints so the tuples hash and compare
        # like ordinary ints.
        u, v, t = int(u), int(v), int(t)
        adj_list_out[u].append((v, t))
        adj_list_in[v].append((u, t))
    return adj_list_in, adj_list_out



def ports(edge_index, adj_list):
    """Assign a port number to every edge based on neighbour arrival order.

    For each node ``v`` in ``adj_list`` its ``(neighbour, time)`` entries are
    ordered by time, and each distinct neighbour ``u`` is numbered by the
    position of its first appearance (0, 1, 2, ...). Every column
    ``(u, v)`` of ``edge_index`` then receives the number stored for that pair.

    Parameters
    ----------
    edge_index : np.ndarray
        Array with two rows; column ``i`` is looked up as ``(neighbour, node)``.
    adj_list : dict
        Maps a node to a list of ``(neighbour, time)`` pairs.

    Returns
    -------
    np.ndarray
        Float array of shape ``(edge_index.shape[1], 1)``.

    Raises
    ------
    KeyError
        If an edge in ``edge_index`` has no matching entry in ``adj_list``.
    """
    port_ids = np.zeros((edge_index.shape[1], 1))
    ports_dict = {}
    for v, nbs in adj_list.items():
        if len(nbs) < 1:
            continue
        a = np.array(nbs)
        # Stable sort: neighbours with equal timestamps keep their input order,
        # so the numbering does not depend on the sort algorithm's internals.
        a = a[a[:, -1].argsort(kind="stable")]
        # np.unique reports the first occurrence of each neighbour; sorting
        # those indices recovers the neighbours in first-appearance order.
        _, first_idx = np.unique(a[:, 0], return_index=True)
        nbs_unique = a[np.sort(first_idx), 0]
        for port, u in enumerate(nbs_unique.tolist()):
            ports_dict[(u, v)] = port
    for i, e in enumerate(edge_index.T):
        port_ids[i] = ports_dict[tuple(e)]
    return port_ids


def add_ports(num_nodes, edge_index, edges):
    """Compute incoming and outgoing port numberings for every edge.

    Parameters
    ----------
    num_nodes : int
        Number of nodes; forwarded to ``to_adj_nodes_with_times``.
    edge_index : np.ndarray
        Array with two rows, one column per edge.
    edges : np.ndarray
        ``(u, v, t)`` rows used to build the timestamped adjacency lists.

    Returns
    -------
    (in_ports, out_ports)
        Two arrays as returned by ``ports``: one computed from the incoming
        adjacency lists, one from the outgoing lists with the rows of
        ``edge_index`` swapped.
    """
    adj_list_in, adj_list_out = to_adj_nodes_with_times(num_nodes, edges)
    in_ports = ports(edge_index, adj_list_in)
    # Return the array itself (not wrapped in a list) so both results have the
    # same type.
    out_ports = ports(edge_index[::-1, :], adj_list_out)
    return in_ports, out_ports



In [None]:
%%timeit 
add_ports(num_nodes, edge_index, edges)

In [None]:
%%cython

import cython 
from cython.parallel cimport prange
import numpy as np
cimport numpy as np

# Build timestamped incoming/outgoing adjacency lists from (u, v, t) rows.
#
# num_nodes: every id in range(num_nodes) gets an (initially empty) list in
#            both dictionaries.
# edges:     C-contiguous int buffer with one (u, v, t) row per edge.
#
# Returns (adj_list_in, adj_list_out) where adj_list_in[v] collects (u, t)
# and adj_list_out[u] collects (v, t), in the order the edges appear.
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef to_adj_nodes_with_times(int num_nodes, int[:, ::1] edges):

    cdef:
        dict adj_list_out = {node: [] for node in range(num_nodes)}
        dict adj_list_in = {node: [] for node in range(num_nodes)}
        int u, v, t
        Py_ssize_t i
        Py_ssize_t num_edges = edges.shape[0]

    # Plain serial loop: the body manipulates Python dicts, lists and tuples,
    # which requires the GIL, so it cannot run inside a nogil prange section.
    for i in range(num_edges):
        u = edges[i, 0]
        v = edges[i, 1]
        t = edges[i, 2]
        adj_list_out[u].append((v, t))
        adj_list_in[v].append((u, t))

    return adj_list_in, adj_list_out


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef ports(edge_index, dict adj_list):
    """Assign a port number to every edge based on neighbour arrival order.

    For each node ``v`` in ``adj_list`` its ``(neighbour, time)`` entries are
    ordered by time and each distinct neighbour is numbered by the position of
    its first appearance. Column ``i`` of ``edge_index`` is then looked up as
    ``(edge_index[0, i], edge_index[1, i])``.

    Parameters
    ----------
    edge_index
        Two-row, 2-D indexable object exposing ``.shape``.
    adj_list : dict
        Maps a node to a list of ``(neighbour, time)`` integer pairs.

    Returns
    -------
    np.ndarray
        float64 array of shape ``(num_edges, 1)``.

    Raises
    ------
    KeyError
        If an edge has no matching ``(neighbour, node)`` entry.
    """
    cdef:
        int num_edges = edge_index.shape[1]
        np.ndarray[np.float64_t, ndim=2] port_ids = np.zeros((num_edges, 1), dtype=np.float64)
        dict ports_dict = {}
        int i, j
        np.ndarray[int, ndim=2] a
        np.ndarray[int, ndim=1] nbs_unique

    for v, nbs in adj_list.items():
        if len(nbs) < 1:
            continue
        a = np.array(nbs, dtype=np.int32)
        # Column 1 is the timestamp of each (neighbour, time) pair. A stable
        # sort keeps the input order for equal timestamps.
        a = a[np.argsort(a[:, 1], kind="stable")]
        # First-occurrence indices, re-sorted, give neighbours in the order
        # they first appear.
        _, idx = np.unique(a[:, 0], return_index=True)
        nbs_unique = a[np.sort(idx)][:, 0]
        for j in range(len(nbs_unique)):
            ports_dict[(nbs_unique[j], v)] = j

    for i in range(num_edges):
        # Fully indexed write so Cython uses the typed buffer access.
        port_ids[i, 0] = ports_dict[(edge_index[0, i], edge_index[1, i])]

    return port_ids


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def add_ports(int num_nodes, int[:, ::1] edge_index, int[:, ::1] edges):
    """Return ``(in_ports, out_ports)`` for the given edges.

    The adjacency lists are built from ``edges``; the incoming lists are
    matched against ``edge_index`` as given, the outgoing lists against
    ``edge_index`` with its two rows swapped.
    """
    cdef dict incoming, outgoing
    cdef np.ndarray ports_in, ports_out

    # Adjacency lists first, then one port lookup per direction.
    incoming, outgoing = to_adj_nodes_with_times(num_nodes, edges)
    ports_in = ports(edge_index, incoming)
    ports_out = ports(edge_index[::-1, :], outgoing)

    return ports_in, ports_out

In [None]:
%%timeit 
add_ports(num_nodes, edge_index, edges)