In [1]:
import copy
import itertools
import math
import random
import sys
from random import randint

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

In [2]:
def min_max_itr(adj_mat):
    """One relaxation pass of the min-max (bottleneck) path product.

    For every pair (i, j) this computes

        min(adj_mat[i][j], min_k max(adj_mat[i][k], adj_mat[k][j]))

    i.e. whether routing through one intermediate vertex k lowers the largest
    edge weight on the path. Repeating the pass until the matrix stops
    changing (with a zero diagonal) yields the min-max path weight between
    every pair of vertices.

    Args:
        adj_mat: square list-of-lists of mutually comparable weights;
            ``float('inf')`` can be used for a missing edge. Not modified.

    Returns:
        A new list-of-lists holding the relaxed values.
    """
    num_vertex = len(adj_mat)
    # Copy instead of aliasing: the result must not overwrite the caller's matrix.
    ret_mat = [list(row) for row in adj_mat]
    for i in range(num_vertex):
        row_i = adj_mat[i]
        for j in range(num_vertex):
            # Generator min works for any comparable weight type (no float sentinel).
            best_via = min(max(row_i[k], adj_mat[k][j]) for k in range(num_vertex))
            if best_via < ret_mat[i][j]:
                ret_mat[i][j] = best_via
    return ret_mat

def _random_upper_graph(num_vert, unique, scale):
    """Draw a random weighted graph as an upper-triangular adjacency matrix.

    Each pair ``i < j`` becomes an edge when ``randint(0, 10) > 5``; its weight
    is ``randint(1, scale)``. ``0`` means "no edge", and the diagonal and lower
    triangle are always ``0``. With ``unique`` set, a weight is redrawn until
    it has not been used before.

    Raises:
        Exception: ``'scale too small'`` if ``unique`` is set and all ``scale``
            distinct weights are already taken (redrawing could never succeed).
    """
    graph = [[0] * num_vert for _ in range(num_vert)]
    used_weights = set()
    for i in range(num_vert):
        for j in range(i + 1, num_vert):
            # Random draws are consumed in a fixed order, so a seed fixes the graph.
            if randint(0, 10) <= 5:
                continue
            weight = randint(1, scale)
            if unique:
                while weight in used_weights:
                    # used_weights is a subset of 1..scale; once full, no redraw can succeed.
                    if len(used_weights) >= scale:
                        raise Exception('scale too small')
                    weight = randint(1, scale)
                used_weights.add(weight)
            graph[i][j] = weight
    return graph


def _random_layout(num_vert, min_dist=2, max_tries=1000):
    """Place nodes 0..num_vert-1 at random integer grid points for plotting.

    Points are redrawn until they are at least ``min_dist`` from every placed
    point. The grid grows with ``num_vert`` (never smaller than 20x20). After
    ``max_tries`` misses the last candidate is accepted, so this always terminates.
    """
    side = max(20, 4 * math.ceil(math.sqrt(num_vert)))
    positions = {}
    for node in range(num_vert):
        tries = 0
        while True:
            candidate = (randint(1, side), randint(1, side))
            tries += 1
            if tries >= max_tries or all(
                math.hypot(candidate[0] - p[0], candidate[1] - p[1]) >= min_dist
                for p in positions.values()
            ):
                break
        positions[node] = candidate
    return positions


def _draw_weighted_graph(adj, pos, title):
    """Draw an adjacency matrix (upper triangle read as undirected) with edge weights."""
    graph = nx.from_numpy_array(np.asarray(adj))
    labels = nx.get_edge_attributes(graph, 'weight')
    nx.draw(graph, with_labels=True, pos=pos)
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels)
    plt.gca().set_title(title)
    plt.show()


def mst_test(num_vert, unique=0, vis=0, seed=2021, verb=0, scale=100, tie_break=0):
    """Check the min-max-closure MST construction against scipy on a random graph.

    A random graph is generated. Its min-max (bottleneck) path weights are
    computed by iterating ``min_max_itr`` until the matrix stops changing. An
    edge ``(i, j)`` is kept only when its own weight equals the bottleneck
    weight between ``i`` and ``j``. The kept edges are then compared with
    scipy's ``minimum_spanning_tree``.

    With repeated weights (``unique=0``), several edges can tie for the
    bottleneck weight. All of them are kept, so the result can contain more
    than one tree's worth of edges and be heavier than an MST.
    ``tie_break=1`` refines every weight by the edge index (weight still
    compared first), so no two edges tie and the kept edges form one MST.

    Args:
        num_vert: number of vertices.
        unique: if truthy, draw pairwise-distinct edge weights.
        vis: if truthy, plot the original, result and expected graphs.
        seed: seed for the module-level ``random`` generator.
        verb: if truthy, print iteration count, weight sums and edge differences.
        scale: edge weights are drawn from ``1..scale``.
        tie_break: if truthy, break weight ties by edge index (see above).

    Returns:
        True when the kept edges have the same total weight as scipy's MST.

    Raises:
        Exception: ``'scale too small'`` when ``unique`` is set and either
            ``scale < 5 * num_vert`` or the distinct weights run out.
    """
    random.seed(seed)
    if unique and scale < 5 * num_vert: raise Exception('scale too small')
    source_graph = _random_upper_graph(num_vert, unique, scale)

    # Reference result: scipy reads the upper-triangular matrix as an undirected graph.
    msp_graph = minimum_spanning_tree(csr_matrix(source_graph)).toarray().astype(int)

    inf = float('inf')
    stride = num_vert * num_vert  # exceeds every lo * num_vert + hi, so weight order dominates

    def edge_key(i, j):
        """Comparison key of edge (i, j): 0 on the diagonal, inf for a non-edge."""
        if i == j:
            return 0
        lo, hi = min(i, j), max(i, j)
        weight = source_graph[lo][hi]
        if weight == 0:
            return inf
        return weight * stride + lo * num_vert + hi if tie_break else weight

    # Symmetric matrix of edge keys; missing edges are inf so they never win a min.
    parsed_adj_mat = [[edge_key(i, j) for j in range(num_vert)] for i in range(num_vert)]

    # Iterate the min-max product to a fixed point. The deepcopy snapshot keeps the
    # comparison valid even if min_max_itr updates its argument in place.
    prev_res = None
    iter_count = 0
    while prev_res != parsed_adj_mat:
        iter_count += 1
        prev_res = copy.deepcopy(parsed_adj_mat)
        parsed_adj_mat = min_max_itr(parsed_adj_mat)
    if verb:
        print(iter_count,'iteration taken')

    # Keep edge (i, j), i < j, only when no path between its endpoints has a
    # smaller largest edge, i.e. its key equals the converged bottleneck value.
    result = [
        [
            source_graph[i][j]
            if i < j and source_graph[i][j] and parsed_adj_mat[i][j] == edge_key(i, j)
            else 0
            for j in range(num_vert)
        ]
        for i in range(num_vert)
    ]

    if vis:
        my_pos = _random_layout(num_vert)
        _draw_weighted_graph(source_graph, my_pos, 'original graph')
        _draw_weighted_graph(result, my_pos, 'result graph')
        _draw_weighted_graph(msp_graph, my_pos, 'expected graph')

    # All MSTs of a graph share the same total weight, so compare sums.
    result_sum = np.array(result).sum()
    expect_sum = msp_graph.sum()
    res = result_sum == expect_sum
    if verb:
        print('result weight sum: ',result_sum)
        print('expect weight sum: ',expect_sum)

    # Count upper-triangle entries (diagonal included) where the two edge sets differ.
    num_diff = sum(
        1
        for i in range(num_vert)
        for j in range(i, num_vert)
        if msp_graph[i][j] != result[i][j]
    )
    if verb:
        if num_diff > 0: print('{} edge are different'.format(num_diff))
        else: print('msp identical')

    return bool(res)



In [5]:
# Run several random trials of mst_test. Trial seeds come from a dedicated,
# seeded generator: mst_test reseeds the global `random` module, so drawing
# seeds from it made the first trial depend on kernel state and chained the rest.
NUM_VERT = 64
SCALE = 20              # only 20 distinct weights for ~900 expected edges -> many ties
SEED_DRAW_SEED = 2021   # change to sample a different set of trial seeds
total = 10
seed_rng = random.Random(SEED_DRAW_SEED)
trial_seeds = [seed_rng.randint(1, 1000) for _ in range(total)]
passes = sum(
    1
    for trial_seed in trial_seeds
    if mst_test(num_vert=NUM_VERT, unique=0, vis=0, seed=trial_seed, verb=0, scale=SCALE)
)
fails = total - passes
print('{}/{} cases passed'.format(passes,total))

0/10 cases passed
