## **Problem 3:** Design a PySpark program to implement Prim’s algorithm for Minimum Spanning Tree

In [5]:
import math
class Node:
    """A graph vertex that also carries the bookkeeping used by Prim's algorithm.

    Equality and hashing are based on ``value``; ordering is based on
    ``length_from_previous_node`` so that a list of nodes can be sorted by
    the cost of the cheapest known edge connecting each node to the tree.
    """

    def __init__(self, value, neighbors=None):
        """Create a node.

        Args:
            value: Identifier of the vertex.
            neighbors: Optional initial adjacency list. When omitted, a new
                empty list is created for this node.
        """
        self.value = value
        # Adjacency list; Graph.add_edge appends (neighbor_node, weight) tuples.
        self.neighbors = [] if neighbors is None else neighbors
        # Weight of the cheapest known edge linking this node to the tree;
        # infinite until some edge to the node has been seen.
        self.length_from_previous_node = math.inf
        # Value of the node on the other end of that cheapest edge.
        self.previous_node = None
        # True once the node has been moved into the spanning tree.
        self.visited = False

    def has_neighbors(self):
        """Return True if at least one neighbor has been added."""
        return len(self.neighbors) > 0

    def number_of_neighbors(self):
        """Return the number of entries in the adjacency list."""
        return len(self.neighbors)

    def add_neighbor(self, neighbor):
        """Append ``neighbor`` to the adjacency list."""
        self.neighbors.append(neighbor)

    def __eq__(self, other):
        """Two nodes are equal when their ``value`` attributes are equal."""
        # Comparing with an unrelated object (e.g. ``node == 5``) must not
        # raise; returning NotImplemented lets Python fall back to "not equal".
        if not hasattr(other, "value"):
            return NotImplemented
        return self.value == other.value

    def __hash__(self):
        """Hash consistently with ``__eq__`` so nodes work in sets and dicts."""
        # Defining __eq__ alone would set __hash__ to None (unhashable).
        return hash(self.value)

    def __gt__(self, other):
        """Order nodes by ``length_from_previous_node``."""
        if not hasattr(other, "length_from_previous_node"):
            return NotImplemented
        return self.length_from_previous_node > other.length_from_previous_node

    def __lt__(self, other):
        """Order nodes by ``length_from_previous_node``."""
        # list.sort() and min() compare with ``<``; defining it directly
        # avoids relying on the reflected ``__gt__`` fallback.
        if not hasattr(other, "length_from_previous_node"):
            return NotImplemented
        return self.length_from_previous_node < other.length_from_previous_node

    def __str__(self):
        """Render the node as the tree edge ``previous_node -> value``."""
        return f"{self.previous_node} -> {self.value}"


In [6]:
class Graph:
    """Undirected weighted graph stored as a flat list of node objects.

    Nodes are looked up by their ``value`` attribute. Edges live in each
    node's adjacency list as ``(neighbor_node, weight)`` tuples.
    """

    def __init__(self, nodes=None):
        """Create a graph, optionally from an existing list of nodes."""
        self.nodes = [] if nodes is None else nodes

    def add_node(self, node):
        """Append ``node`` to the graph."""
        self.nodes.append(node)

    def find_node(self, value):
        """Return the first node whose ``value`` equals ``value``, else None."""
        return next((node for node in self.nodes if node.value == value), None)

    def add_edge(self, value1, value2, weight=1):
        """Add an undirected edge between the nodes identified by the two values.

        If either node is missing, an error message is printed and the graph
        is left unchanged.
        """
        node1 = self.find_node(value1)
        node2 = self.find_node(value2)
        if node1 is None or node2 is None:
            print("Error: One or more nodes were not found")
            return
        # Undirected: record the edge in both adjacency lists.
        node1.add_neighbor((node2, weight))
        node2.add_neighbor((node1, weight))

    def number_of_nodes(self):
        """Return a human-readable sentence stating how many nodes exist."""
        return f"The graph has {len(self.nodes)} nodes"

    def are_connected(self, node_one, node_two):
        """Return True if ``node_one`` has a direct edge to ``node_two``.

        Both arguments are node values. If either value is not present in
        the graph the nodes cannot be connected, so False is returned.
        """
        first = self.find_node(node_one)
        second = self.find_node(node_two)
        # Guard against unknown values instead of failing on ``None.neighbors``.
        if first is None or second is None:
            return False
        return any(entry[0].value == second.value for entry in first.neighbors)

    def __str__(self):
        """Return one line per node, each terminated by a newline."""
        return "".join(f"{str(node)}\n" for node in self.nodes)

In [7]:
class Prim:
    """Prim's minimum-spanning-tree algorithm over a graph of node objects.

    The graph must expose ``nodes`` (a list) and ``find_node(value)``. Each
    node must expose ``value``, ``neighbors`` (``(node, weight)`` entries),
    ``length_from_previous_node``, ``previous_node`` and ``visited``.
    """

    def __init__(self, graph, start):
        """Prepare a run on ``graph`` starting from the node valued ``start``."""
        self.graph = graph
        self.start = start
        # Nodes in the order they were added to the spanning tree.
        self.tree = []
        # Work on a copy: execution() consumes this list, and aliasing
        # graph.nodes would empty the caller's graph as a side effect.
        self.vertices = list(self.graph.nodes)

    def calculate_total_cost(self):
        """Return the sum of the connecting-edge weights of all tree nodes."""
        return sum(node.length_from_previous_node for node in self.tree)

    def execution(self):
        """Run Prim's algorithm and return ``(tree, total_cost)``.

        ``tree`` lists the nodes in the order they joined the spanning tree;
        each node's ``previous_node`` / ``length_from_previous_node`` describe
        the edge that connected it. Nodes that cannot be reached from the
        start keep an infinite length, so ``total_cost`` is infinite for a
        disconnected graph.

        Raises:
            ValueError: If the start value does not match any node.
        """
        start_node = self.graph.find_node(self.start)
        if start_node is None:
            raise ValueError(
                f"Start node {self.start!r} was not found in the graph"
            )

        # Reset all per-run state so the method can be called repeatedly and
        # always starts from a clean slate.
        self.tree = []
        self.vertices = list(self.graph.nodes)
        for node in self.vertices:
            node.length_from_previous_node = float("inf")
            node.previous_node = None
            node.visited = False

        # Every other node is at infinity, so the start node is picked first.
        start_node.length_from_previous_node = 0

        while self.vertices:
            # Stable sort by key: the cheapest frontier node moves to the
            # front, and ties keep their previous relative order.
            self.vertices.sort(key=lambda node: node.length_from_previous_node)
            selected_node = self.vertices.pop(0)
            selected_node.visited = True
            # Add the selected node to the tree
            self.tree.append(selected_node)
            self._relax_neighbors(selected_node)

        return self.tree, self.calculate_total_cost()

    @staticmethod
    def _relax_neighbors(selected_node):
        """Lower the key of each unvisited neighbor reachable more cheaply."""
        for edge in selected_node.neighbors:
            child, weight = edge[0], edge[1]
            if child.visited:
                continue
            if weight < child.length_from_previous_node:
                child.length_from_previous_node = weight
                child.previous_node = selected_node.value

In [10]:
def main():
    """Build the nine-node sample graph, run Prim from node '1', print the result."""
    graph = Graph()

    # Vertices are labelled with the single-character strings '1' .. '9'.
    for label in "123456789":
        graph.add_node(Node(label))

    # (endpoint, endpoint, weight) for every undirected edge of the sample.
    weighted_edges = (
        ('1', '2', 4),
        ('2', '3', 11),
        ('2', '5', 8),
        ('3', '4', 7),
        ('3', '6', 1),
        ('4', '5', 2),
        ('4', '6', 6),
        ('5', '7', 4),
        ('5', '8', 7),
        ('6', '7', 2),
        ('7', '8', 14),
        ('7', '9', 10),
        ('8', '9', 9),
    )
    for first, second, weight in weighted_edges:
        graph.add_edge(first, second, weight)

    solver = Prim(graph, "1")
    spanning_tree, cost = solver.execution()

    print("Minimum Spanning Tree by Prim's Algorithm")
    for tree_node in spanning_tree:
        print(tree_node)
    print(f"Total Cost: {cost}")

# Run the demo only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()

Minimum Spanning Tree by Prim's Algorithm
None -> 1
1 -> 2
2 -> 5
5 -> 4
5 -> 7
7 -> 6
6 -> 3
5 -> 8
8 -> 9
Total Cost: 37


**The total cost (sum of edge weights) of the Minimum Spanning Tree of the given graph is 37**