In [12]:
import mlx.core as mx
from torch_geometric.utils.sparse import index2ptr
from torch_geometric.datasets import Planetoid
from scipy.sparse import csr_matrix
from torch_geometric.nn import Node2Vec
import torch
from torch_geometric.utils import is_undirected

A unifrom sampling algorithm to create random walks

In [13]:
dataset = Planetoid(root ="data/Cora", name='Cora')

Using undirected graphs for random walks to keep it simple

In [14]:
is_undirected(dataset.edge_index)

True

convert the coo matrix into csr and then verify with numba

In [15]:
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.sparse import index2ptr
from torch.utils.data import DataLoader

In [16]:
data, edge_index = dataset.x, dataset.edge_index

In [17]:
!pip install torch_cluster



In [18]:
from torch_cluster.rw import random_walk

In [19]:
num_nodes = maybe_num_nodes(edge_index=edge_index)
loader = DataLoader(range(num_nodes), batch_size=1000)

In [20]:
start = next(iter(loader))

In [21]:
import time

In [22]:
start_time = time.time()
num_nodes  = maybe_num_nodes(edge_index=edge_index)
row, col = sort_edge_index(edge_index=edge_index, num_nodes=num_nodes)
row_ptr, col = index2ptr(row, num_nodes), col
print(type(col), type(row_ptr))
random_walks = torch.ops.torch_cluster.random_walk(row_ptr, col, start, 1000, 1.0, 1.0)
print("Time taken to perform 1000 random walks with Torch_cluster is", time.time()-start_time)

<class 'torch.Tensor'> <class 'torch.Tensor'>
Time taken to perform 1000 random walks with Torch_cluster is 0.05887198448181152


Torch_cluster is insanely fast and has really good performance even on CPU

In [23]:
import numpy as np

Create own random walk algorithm and measure the time taken on CPU

In [24]:
def random_walk(row_ptr, col, start, walk_length):
    """
    Computes random walks of length `walk_length` starting from node indices `start` in the
    graph given by `(row_ptr, col)` as adjacency matrix in compressed sparse row (CSR) format.

    Args:
        row_ptr (LongTensor): Row pointers of the adjacency matrix in CSR format.
        col (LongTensor): Column indices of the adjacency matrix in CSR format.
        start (LongTensor): Indices of starting nodes for random walks.
        walk_length (int): Length of random walks.

    Returns:
        LongTensor: Tensor of shape `(num_starts, walk_length)` containing the nodes indices
        of the random walks.
    """

    start = start.flatten()
    num_starts = start.shape[0]
    out = np.empty((num_starts, walk_length), dtype= col.dtype)
    for l in range(walk_length):
        if l == 0:
            out[:, l] = start
        else:
            prev = out[:, l - 1]
            
            prev_nbrs_start = row_ptr[prev]
            prev_nbrs_end = row_ptr[prev + 1]
            
            prev_nbrs = [col[start_idx:end_idx] for start_idx, end_idx in zip(prev_nbrs_start, prev_nbrs_end)]
            # Generate random neighbor indices
            rand_idx = [np.random.randint(0, len(nbrs)) for nbrs in prev_nbrs]
            # Get the corresponding neighbors
            next_nbrs = [nbrs[idx] for nbrs, idx in zip(prev_nbrs, rand_idx)]
            out[:, l] = next_nbrs

    return out

In [25]:
start_time = time.time()
num_nodes = maybe_num_nodes(edge_index=edge_index)
row, col = sort_edge_index(edge_index=edge_index, num_nodes=num_nodes)
row_numpy = row.numpy()
unique_vals, counts = np.unique(row_numpy, return_counts=True)
row_ptr_numpy = np.cumsum(counts)
row_ptr_numpy = np.insert(row_ptr_numpy, 0, 0)
random_walk(row_ptr_numpy, col.numpy(), start.numpy(), walk_length=1000)
print("time taken for random walks using numpy is ", time.time()-start_time)

time taken for random walks using numpy is  1.4987566471099854


Running simulations for mlx

In [26]:
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index

In [27]:
cora_dataset = PlanetoidDataset(name='cora', base_dir="~")

Loading cora data ... Done


In [28]:
edge_index = cora_dataset.graphs[0].edge_index

In [33]:
def random_walk_mlx(row_ptr:mx.array, col: mx.array, start:mx.array, walk_length: int):
    row_ptr_numpy = np.array(row_ptr, copy = False)
    col_numpy = np.array(col, copy= False)
    num_starts = start.shape[0]
    out = np.zeros((num_starts, walk_length), dtype = col_numpy.dtype)
    for l in range(walk_length):
        if l == 0:
            out[:, l] = start
        else:
            prev = out[:, l - 1]
            prev_nbrs_start = row_ptr_numpy[prev]
            prev_nbrs_end = row_ptr_numpy[prev + 1]
            
            prev_nbrs = [col_numpy[start_idx:end_idx] for start_idx, end_idx in zip(prev_nbrs_start, prev_nbrs_end)]
            # Generate random neighbor indices
            rand_idx = [np.random.randint(0, len(nbrs)) for nbrs in prev_nbrs]
            # Get the corresponding neighbors
            next_nbrs = [nbrs[idx] for nbrs, idx in zip(prev_nbrs, rand_idx)]
            out[:, l] = next_nbrs
    
    return out

In [39]:
start_time  = time.time()
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
unique_vals, counts_mlx = np.unique(np.array(row_mlx, copy = False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]),mx.array(cum_sum_mlx)])
random_walk_mlx(row_ptr_mlx, col_mlx, start=start, walk_length=1000)
print("Time taken by mlx and numpy is ", time.time()-start_time)

Time taken by mlx and numpy is  1.8042528629302979
