In [1]:
import random
from random import randint

import networkx as nx
import numpy as np
from numba.decorators import jit

In [2]:
# method 0 - reference answer: let networkx count the components
def method0(nv, edges):
    """Return the number of connected components as computed by networkx.

    Args:
        nv: number of vertices, labelled 0..nv-1. Every one of them is put
            into the graph, so an isolated vertex counts as its own component.
        edges: iterable of (a, b) pairs of 0-based vertex indices.

    Returns:
        Number of connected components.
    """
    graph = nx.empty_graph(nv)
    graph.add_edges_from(edges)
    return nx.number_connected_components(graph)

In [3]:
# method 1 - run the program under test as a subprocess
def method1(nv, edges, pname='count-comps.py'):
    """Count connected components by running an external solution script.

    The graph is serialised in the program's input format: a header line
    "<nv> <ne>" followed by one "a b" line per edge, with the 0-based vertex
    indices used in this notebook converted to 1-based labels.

    Args:
        nv: number of vertices.
        edges: iterable of (a, b) pairs of 0-based vertex indices.
        pname: path of the Python script under test; it is run with the same
            interpreter as this notebook.

    Returns:
        The first whitespace-separated token of the script's stdout, as int.

    Raises:
        RuntimeError: if the script exits with a non-zero status or prints
            nothing to stdout (the message includes the script's stderr).
    """
    from sys import executable
    from subprocess import run, PIPE

    edges = list(edges)
    lines = ['%s %s' % (nv, len(edges))]
    lines += ['%s %s' % (a + 1, b + 1) for a, b in edges]
    sdata = '\n'.join(lines) + '\n'

    # stderr is captured separately: merging it into stdout would let any
    # warning the script prints end up in front of the answer we parse.
    res = run([executable, pname], stdout=PIPE, stderr=PIPE,
              input=sdata.encode('utf8'))
    tokens = res.stdout.split()
    if res.returncode != 0 or not tokens:
        raise RuntimeError('%s failed (exit code %d):\n%s'
                           % (pname, res.returncode,
                              res.stderr.decode('utf8', 'replace')))
    return int(tokens[0])

In [4]:
# method 2 - recursive depth-first search over adjacency lists
def method2(nv, edges):
    """Count connected components with a recursive DFS.

    Args:
        nv: number of vertices, labelled 0..nv-1.
        edges: iterable of (a, b) pairs of 0-based vertex indices.

    Returns:
        Number of connected components (each isolated vertex counts as one).

    Note: the recursion depth grows with the length of the DFS path, so a
    large path-like graph can exceed Python's recursion limit. The recursion
    is kept on purpose; method3 is the iterative variant.
    """
    neighbours = [[] for _ in range(nv)]
    for u, w in edges:
        neighbours[u].append(w)
        neighbours[w].append(u)

    seen = [False] * nv

    def visit(u):
        seen[u] = True
        for w in neighbours[u]:
            if not seen[w]:
                visit(w)

    components = 0
    for start in range(nv):
        if seen[start]:
            continue
        components += 1
        visit(start)
    return components

In [5]:
# method3 - iterative DFS with an explicit stack over adjacency sets
def method3(nv, edges):
    """Count connected components without recursion.

    Adjacency is stored as sets, so duplicate edges collapse. A vertex is
    marked as seen when it is pushed, so it is never pushed twice.

    Args:
        nv: number of vertices, labelled 0..nv-1.
        edges: iterable of (a, b) pairs of 0-based vertex indices.

    Returns:
        Number of connected components.
    """
    neighbours = [set() for _ in range(nv)]
    for u, w in edges:
        neighbours[u].add(w)
        neighbours[w].add(u)

    seen = [False] * nv
    components = 0

    for start in range(nv):
        if not seen[start]:
            components += 1
            seen[start] = True
            pending = [start]
            while pending:
                u = pending.pop()
                fresh = [w for w in neighbours[u] if not seen[w]]
                for w in fresh:
                    seen[w] = True
                pending.extend(fresh)
    return components

In [6]:
# method4 - dense adjacency matrix + explicit stack, compiled with numba
@jit(nopython=True)
def _method4_kernel(nv, edge_arr):
    """Compiled core of method4: count components of a graph given as an edge array.

    Args:
        nv: number of vertices, labelled 0..nv-1.
        edge_arr: integer array of shape (ne, 2); row k holds the 0-based
            endpoints of edge k. Entries must already be in range(nv).

    Returns:
        Number of connected components.
    """
    adj = np.zeros((nv, nv), dtype=np.bool_)
    for k in range(edge_arr.shape[0]):
        a = edge_arr[k, 0]
        b = edge_arr[k, 1]
        adj[a, b] = True
        adj[b, a] = True

    visited = np.zeros(nv, dtype=np.bool_)
    # A vertex is marked visited when it is pushed, so each vertex is pushed
    # at most once and nv slots are always enough.
    stack = np.zeros(nv, dtype=np.int32)
    nc = 0

    for i in range(nv):
        if visited[i]:
            continue
        nc += 1
        visited[i] = True
        ns = 0
        v = i
        while True:
            # Scan row v of the matrix and push every unvisited neighbour.
            for w in range(nv):
                if adj[v, w] and not visited[w]:
                    visited[w] = True
                    stack[ns] = w
                    ns += 1
            if ns == 0:
                break
            ns -= 1
            v = stack[ns]

    return nc


def method4(nv, edges):
    """Count connected components with the numba-compiled dense-matrix DFS.

    Args:
        nv: number of vertices, labelled 0..nv-1.
        edges: iterable of (a, b) pairs of 0-based vertex indices.

    Returns:
        Number of connected components.

    Raises:
        ValueError: if an edge endpoint is outside range(nv).
    """
    # Give the kernel a 2-D int64 array instead of the caller's list of tuples.
    # numba cannot infer a type for an empty list (e.g. a 1-vertex graph with
    # no edges), and reflected lists are deprecated in numba.
    edge_arr = np.asarray(list(edges), dtype=np.int64).reshape(-1, 2)
    # Compiled code does not bounds-check indexing by default, so validate
    # here rather than silently writing outside the matrix.
    if edge_arr.size and (edge_arr.min() < 0 or edge_arr.max() >= nv):
        raise ValueError('edge endpoint outside range(nv)')
    return _method4_kernel(nv, edge_arr)

# Warm-up call so numba compiles the method4 code here, not inside the timed
# benchmark below (input: 5 vertices, each carrying only a self-loop).
_ = method4(5, [(v, v) for v in range(5)])
print('ok')

ok


In [7]:
# Small hand-written cases in the program's input format: a header line
# "<nv> <ne>" followed by one 1-based "a b" line per edge.
# sdata1: 4 vertices, edges 1-2 and 3-2 -> components {1,2,3} and {4}.
sdata1 = '''4 2
1 2
3 2
'''
# sdata2: 4 vertices forming the path 1-2-3-4 -> a single component.
sdata2 = '''4 3
1 2
3 2
4 3
'''

# sdata3: the header declares 0 edges but 10 edge lines follow, including
# self-loops and the duplicate pair 5-1 / 1-5. The checking loop takes the
# edge count from the actual lines, not from the header.
sdata3 = '''6 0
6 6
1 1
2 2
3 3
4 4
5 1
1 5
3 2
4 3
2 1
'''

# sdata4: header "0 0" with four copies of the self-loop 1-1. Because the
# declared nv is not positive, the checking loop infers nv from the largest
# vertex label (here 1).
sdata4 = '''0 0
1 1
1 1
1 1
1 1
'''

In [8]:
# Check every method against the networkx reference (method0) on the small
# hand-written cases above.
for case, sdata in enumerate([sdata1, sdata2, sdata3, sdata4]):
    rdata = [[int(x) for x in s.split()] for s in sdata.splitlines()]
    nv, ne = rdata[0]
    edges = rdata[1:]

    # The header's edge count is not trusted (sdata3 declares 0 edges), and a
    # non-positive vertex count (sdata4) is replaced by the largest label.
    ne = len(edges)
    if edges and nv <= 0:
        nv = max(max(a, b) for a, b in edges)
    assert all(1 <= a <= nv and 1 <= b <= nv for a, b in edges), \
        'sdata%d: vertex label outside 1..%d' % (case + 1, nv)
    # Convert to the 0-based pairs every method expects.
    edges = [(a - 1, b - 1) for a, b in edges]

    nc0 = method0(nv, edges)
    for method in (method1, method2, method3, method4):
        nc = method(nv, edges)
        assert nc == nc0, ('sdata%d: %s returned %r, networkx says %r'
                           % (case + 1, method.__name__, nc, nc0))

print('ok')

ok


In [9]:
# Check on a large set of random graphs. networkx's answer is precomputed
# once per graph and stored next to the arguments in all_tests.
# A fixed seed makes the test set, and therefore the benchmark, reproducible.
SEED = 12345
max_nodes = 100
max_graphs = 10000

_rng = random.Random(SEED)
# gnm_random_graph(n, m) draws m random edges on n vertices; when m exceeds
# the number of possible edges it returns the complete graph, and a 1-vertex
# graph therefore has no edges at all.
_rand_graphs = (nx.gnm_random_graph(_rng.randint(1, max_nodes),
                                    _rng.randint(1, max_nodes),
                                    seed=_rng.randrange(2**32))
                for _ in range(max_graphs))
_all_args = ((g.number_of_nodes(), list(g.edges)) for g in _rand_graphs)
all_tests = [(args, method0(*args)) for args in _all_args]

In [76]:
# Time each method over the random test set. The assert inside each %timeit
# also checks every answer against the networkx result stored in all_tests,
# so the reported time includes that comparison.
# "-n1 -r1" means one loop of one run: each figure is a single measurement
# with no spread (hence "± 0 ns" in the output).
# method1 starts a subprocess per graph, so it is timed on the first 10
# graphs only; its time is not directly comparable with the other methods.
%timeit -n1 -r1 assert all(method1(*args) == ret for args, ret in all_tests[:10])
print('ok1')

%timeit -n1 -r1 assert all(method2(*args) == ret for args, ret in all_tests)
print('ok2')

%timeit -n1 -r1 assert all(method3(*args) == ret for args, ret in all_tests)
print('ok3')

# Relies on the warm-up call after method4's definition having already
# triggered numba compilation.
%timeit -n1 -r1 assert all(method4(*args) == ret for args, ret in all_tests)
print('ok4')

2.05 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ok1
1.61 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ok2
2.23 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ok3
1.03 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ok4
