In [1]:
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

In [2]:
# Dense, symmetric adjacency matrix of the 3-node path graph 0 - 1 - 2.
adjacency_matrix = torch.tensor(
    [
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 0],
    ],
    dtype=torch.float32,
)
print(adjacency_matrix)

tensor([[0., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.]])


In [3]:
# Convert the tensor to a NumPy array explicitly rather than letting scipy call
# np.asarray on a torch tensor: that implicit conversion raises for tensors that
# require grad or live on a GPU. detach().cpu() is a no-op for a plain CPU tensor.
csr_from_adjacency = csr_matrix(adjacency_matrix.detach().cpu().numpy())
print(csr_from_adjacency)

<Compressed Sparse Row sparse matrix of dtype 'float32'
	with 4 stored elements and shape (3, 3)>
  Coords	Values
  (0, 1)	1.0
  (1, 0)	1.0
  (1, 2)	1.0
  (2, 1)	1.0


In [4]:
# Edges written as (source, target) pairs, each listed in both directions, then
# transposed into a (2, num_edges) layout: row 0 = sources, row 1 = targets.
edge_pairs = [(0, 1), (1, 0), (1, 2), (2, 1), (4, 5), (5, 4)]
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
print(edge_index)

tensor([[0, 1, 1, 2, 4, 5],
        [1, 0, 2, 1, 5, 4]])


In [5]:
# One unit weight per edge (i.e. per column of edge_index), with the same dtype
# and device as edge_index.
data = torch.ones(edge_index.size(1), dtype=edge_index.dtype, device=edge_index.device)
print(data)

tensor([1, 1, 1, 1, 1, 1])


In [6]:
# Total number of nodes in the graph. With the edge_index above (largest node id 5),
# nodes 3, 6 and 7 have no edges and therefore become single-node components;
# presumably this is intentional, to demonstrate isolated nodes.
num_nodes = 8
assert edge_index.numel() == 0 or int(edge_index.max()) < num_nodes, (
    "edge_index references a node id >= num_nodes"
)

# scipy documents the COO-style constructor as csr_matrix((data, (row_ind, col_ind))),
# so unpack the (2, num_edges) edge_index into explicit source/target arrays.
row_ind, col_ind = edge_index.numpy()
csr_from_edge_index = csr_matrix(
    (data.numpy(), (row_ind, col_ind)), shape=(num_nodes, num_nodes)
)
print(csr_from_edge_index)

<Compressed Sparse Row sparse matrix of dtype 'int64'
	with 6 stored elements and shape (8, 8)>
  Coords	Values
  (0, 1)	1
  (1, 0)	1
  (1, 2)	1
  (2, 1)	1
  (4, 5)	1
  (5, 4)	1


In [7]:
# Treat edges as undirected; labels[i] is the component id assigned to node i.
n_components, labels = connected_components(
    csr_from_edge_index,
    directed=False,
    return_labels=True,
)
print(n_components, labels, sep="\n")

5
[0 0 0 1 2 2 3 4]
