In [1]:
import networkx as nx

def print_for_graphviz(mst: list):
    """Print the edges of ``mst`` as a Graphviz DOT ``digraph``.

    Each edge becomes one ``u -> v[label=weight]`` line, so the output can be
    pasted straight into ``dot`` for visualisation.
    """
    print("digraph G {")
    for e in mst:
        print(f"{e.u!s} -> {e.v!s}[label={e.weight!s}]")
    print("}")

# Verification helpers
def subgraph(mst: list, parentEdgesArray:list):
    """Return True if every edge in ``mst`` also appears in ``parentEdgesArray``.

    Membership is tested with ``in`` (i.e. ``==``). ``Edge`` defines no
    ``__eq__``, so for Edge objects this means the very same object must be
    present in the parent list; an equal-looking copy does not count.
    """
    return all(edge in parentEdgesArray for edge in mst)

# Verification helper: nx.is_tree tests for both cycles and connectivity.
def connected(mst:list):
    """Return True if the edges in ``mst`` form a single tree.

    The endpoints are loaded into a networkx ``DiGraph`` and checked with
    ``nx.is_tree``. For directed graphs that function tests whether the
    underlying undirected graph is a tree, so the check fails on a
    disconnected edge set as well as on a cyclic one.

    An empty ``mst`` returns False. networkx raises
    ``NetworkXPointlessConcept`` for a graph with no nodes, which would
    otherwise crash ``verify`` instead of reporting a result.
    """
    if not mst:
        return False
    graph = nx.DiGraph()
    graph.add_edges_from((edge.u, edge.v) for edge in mst)
    return nx.is_tree(graph)

def has_required_vertices(mst: list, required_vertices: list):
    """Return True if every vertex in ``required_vertices`` is an endpoint of some edge in ``mst``.

    The check is a set-subset test. Duplicate entries in ``required_vertices``
    therefore do not cause a false negative, which the previous
    count-comparison did. Membership is also O(1) instead of a linear scan of
    the list for every endpoint.

    Args:
        mst: edges with ``u`` and ``v`` endpoint attributes.
        required_vertices: vertex labels that must be covered.

    Returns:
        True if all required vertices are covered (vacuously True when
        ``required_vertices`` is empty), otherwise False.
    """
    required = set(required_vertices)
    covered = {endpoint for edge in mst for endpoint in (edge.u, edge.v)}
    return required <= covered

def cost_of_graph(mst: list):
    """Return the total weight of the edges in ``mst`` (0 for an empty list)."""
    return sum(edge.weight for edge in mst)

def verify(mst: list, parentEdgesArray: list, required_vertices: list):
    """Run the verification checks on ``mst`` and print the first failure, or a success message.

    The checks run in order (subgraph of parent, connected tree, covers the
    required vertices), and evaluation stops at the first one that fails.
    Nothing is returned; the result is only printed.
    """
    checks = (
        (lambda: subgraph(mst, parentEdgesArray), "MST is not a subgraph of the parent graph"),
        (lambda: connected(mst), "MST is not connected"),
        (lambda: has_required_vertices(mst, required_vertices), "MST does not contain all required vertices"),
    )
    for passes, failure_message in checks:
        if not passes():
            print(failure_message)
            return
    print("MST is complete, is a subgraph of parent, and contains all required vertices")


In [2]:
class Edge():
    """A weighted edge between vertices ``u`` and ``v``.

    No ``__eq__`` is defined, so two Edge objects compare equal only if they
    are the same object. A ``__repr__`` is provided so edge lists display
    readably in notebook output instead of as ``<Edge object at 0x...>``.
    """

    def __init__(self,u: int, v: int, weight: int) -> None:
        self.u = u  # first endpoint
        self.v = v  # second endpoint
        self.weight = weight

    def __repr__(self) -> str:
        return f"Edge(u={self.u!r}, v={self.v!r}, weight={self.weight!r})"

class DSU():
    """Disjoint-set union (union-find) over the integers ``0 .. num_verts - 1``.

    ``union`` attaches the root with the smaller ``height`` under the other
    (union by height/rank), and ``find`` compresses paths it walks.
    """

    def __init__(self, num_verts: int):
        # parent[i] is i's parent in its set's tree; a root is its own parent.
        self.parent = list(range(num_verts))
        # Rank of each root, used by union to decide which root becomes the parent.
        self.height = [1] * num_verts

    def find(self, x : int) -> int:
        """Return the root (set id) of the set containing ``x``.

        Every node visited on the way up is re-pointed directly at the root.
        """
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: path compression.
        node = x
        while self.parent[node] != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self,x: int, y: int) -> bool:
        """Merge the sets containing ``x`` and ``y``.

        Returns True if they were in different sets (a merge happened), or
        False if they already shared a root.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        if self.height[root_x] < self.height[root_y]:
            # The lower-ranked root goes under the higher-ranked one.
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            if self.height[root_x] == self.height[root_y]:
                self.height[root_x] += 1
        return True


def kruskals(edges: list, num_verts: int) -> list:
    """Return a minimum spanning forest using Kruskal's algorithm.

    Args:
        edges: Edge objects whose ``u``/``v`` are used directly as indices
            into the DSU arrays, so they are expected to be 0-based and in
            ``range(num_verts)``.
        num_verts: number of vertices.

    Returns:
        The selected Edge objects (the same instances found in ``edges``), in
        non-decreasing weight order. This is at most ``num_verts - 1`` edges,
        and fewer if the graph is disconnected.

    The caller's ``edges`` list is not reordered. The edges are iterated from
    a sorted copy, so the notebook's ``edgesArray`` is not mutated as a side
    effect. The sort is stable, so ties are taken in input order.
    """
    dsu = DSU(num_verts)
    mst = []
    target = num_verts - 1
    for edge in sorted(edges, key=lambda edge: edge.weight):
        # A spanning tree of num_verts vertices has num_verts - 1 edges; stop early.
        if len(mst) >= target:
            break
        # union() succeeds only when the endpoints are in different components,
        # i.e. when adding this edge cannot create a cycle.
        if dsu.union(edge.u, edge.v):
            mst.append(edge)
    return mst

In [3]:
# Read the problem instance from custom_input.txt:
#   line 1: <vertex count> <edge count> <required-vertex count>
#   line 2: the required vertex labels (1-indexed, same labels as the edge lines)
#   next <edge count> lines: "<u> <v> <weight>" with 1-indexed endpoints
with open('custom_input.txt', 'r') as fileIn:
    header = [int(token) for token in fileIn.readline().split()]
    parentVertices = header[0]
    edges = header[1]  # number of edge lines that follow (an int, not the edges themselves)
    numOfRVertices = header[2]  # currently unused in the visible cells

    rVertices = [int(token) for token in fileIn.readline().split()]

    # Endpoints are shifted to 0-indexed so they can index the DSU arrays in
    # kruskals; weights are kept unchanged.
    edgesArray = [
        Edge(int(fields[0]) - 1, int(fields[1]) - 1, int(fields[2]))
        for fields in (fileIn.readline().split() for _ in range(edges))
    ]
    

In [4]:
mst = kruskals(edgesArray, parentVertices)
print("Cost of MST: ", cost_of_graph(mst))

# Convert back to the 1-indexed labels used by the input file and rVertices.
# Shift EVERY parent edge, not just the MST ones. The MST holds references to
# the same Edge objects as edgesArray, so shifting only the MST subset left
# edgesArray with a mix of 0- and 1-indexed edges.
# NOTE: this mutates edgesArray in place, so re-run the input cell (In [3])
# before re-running this one.
for edge in edgesArray:
    edge.u += 1
    edge.v += 1

verify(mst, edgesArray, rVertices)



Cost of MST:  1101
MST is complete, is a subgraph of parent, and contains all required vertices


In [5]:
def prune(mst: list, v: int):
    """Remove every edge incident to vertex ``v`` from ``mst``.

    ``mst`` is modified in place and also returned, so ``mst = prune(mst, v)``
    keeps working. The list is rebuilt with slice assignment rather than by
    calling ``list.remove`` while iterating over it. That approach skipped
    the element immediately after each removed edge, so a vertex with several
    incident edges could keep some of them.
    """
    mst[:] = [edge for edge in mst if edge.u != v and edge.v != v]
    return mst

# Repeatedly strip leaf vertices that are not required, until a full pass
# removes nothing. What remains connects the required vertices.
while True:
    size_before = len(mst)

    leaf_graph = nx.Graph()
    leaf_graph.add_edges_from((edge.u, edge.v) for edge in mst)

    # Degrees come from this snapshot (pruning below does not update it),
    # so the prunable leaves can be collected up front.
    prunable = [
        node for node in leaf_graph.nodes()
        if leaf_graph.degree[node] == 1 and node not in rVertices
    ]
    for node in prunable:
        mst = prune(mst, node)

    if len(mst) == size_before:
        break

print("Cost of MST: ", cost_of_graph(mst))
verify(mst, edgesArray, rVertices)


Cost of MST:  587
MST is complete, is a subgraph of parent, and contains all required vertices


In [6]:
import make_input

#make_input.make_input()