# Speeding up Edge Contraction Algorithm

In [1]:
import sys

import numpy as np
import torch

sys.path.append("..")

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Roadmap

- [X] Create toy graph
- [X] Get timings of original algorithm
- [X] Implement CuGraph connected components
- [X] Get CC timings
- [ ] Implement CC into the PyGeometric function
- [ ] Explore a vectorized version of original idea - only one edge contracted per node
- [ ] Explore sorting vs. random choice of edge

### Toy Graph

In [3]:
num_nodes = 100000
num_edges = 1000000
x = torch.rand((num_nodes, 3), device=device).float()
e = torch.randint(0, len(x), (2, num_edges), device=device).long()
edge_score = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float()*0.4,
    torch.rand(int(num_edges*0.1), device=device).float()
])

### Original Algorithm

In [4]:
from torch_scatter import scatter_add
from torch_sparse import coalesce

In [56]:
def __merge_edges_original__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx].item()
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx].item()
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch

In [66]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


In [114]:
%%time
new_x, new_edge_index, new_batch = __merge_edges__(x, e, torch.zeros(x.shape[0], device=device).long(), edge_score)

CPU times: user 6.27 s, sys: 15.6 ms, total: 6.29 s
Wall time: 6.29 s


### CuGraph Connected Components

In [71]:
import cugraph
import cudf
import pandas as pd
import cupy as cp
from torch.utils.dlpack import from_dlpack, to_dlpack

In [125]:
passing_edges = e[:, edge_score > 0.5]

In [109]:
%%time
passing_edges = cudf.from_dlpack(to_dlpack(passing_edges.T))

CPU times: user 0 ns, sys: 1.48 ms, total: 1.48 ms
Wall time: 1.22 ms


  res = libdlpack.from_dlpack(pycapsule_obj)


In [110]:
%%time
G = cugraph.Graph()
G.from_cudf_edgelist(passing_edges, source=0, destination=1, edge_attr=None)

CPU times: user 51.6 ms, sys: 3.87 ms, total: 55.5 ms
Wall time: 55.8 ms


In [111]:
%%time
labels = cugraph.components.connectivity.weakly_connected_components(G)

CPU times: user 8.61 ms, sys: 3.73 ms, total: 12.3 ms
Wall time: 11.5 ms


This all seems to work fine, so let's build it into a new method

#### TODO

In [66]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


## One-node, One-edge Constraint

In [18]:
from torch_scatter import scatter_max

In [32]:
%%time
max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

CPU times: user 869 µs, sys: 0 ns, total: 869 µs
Wall time: 511 µs


In [20]:
stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

In [21]:
max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

In [22]:
max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

In [23]:
max_indices

tensor([621952, 905298, 928899,  ..., 988551, 436959, 991108], device='cuda:0')

Get timing for comparison

In [17]:
%%time

max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

CPU times: user 967 µs, sys: 2.25 ms, total: 3.22 ms
Wall time: 2.67 ms


In [28]:
%%time
torch.unique(stacked_indices, return_counts=True)

CPU times: user 1.46 ms, sys: 0 ns, total: 1.46 ms
Wall time: 973 µs


(tensor([     10,      23,      33,  ...,  999998,  999999, 1000000],
        device='cuda:0'),
 tensor([1, 1, 1,  ..., 2, 1, 7], device='cuda:0'))