diff --git a/README.md b/README.md index 06fb2d4c4..1fed65ab4 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ If you want to uninstall algorithms, it is as simple as: - [markov_chain](algorithms/graph/markov_chain.py) - [minimum_spanning_tree](algorithms/graph/minimum_spanning_tree.py) - [satisfiability](algorithms/graph/satisfiability.py) + - [minimum_spanning_tree_prims](algorithms/graph/prims_minimum_spanning.py) - [tarjan](algorithms/graph/tarjan.py) - [traversal](algorithms/graph/traversal.py) - [maximum_flow](algorithms/graph/maximum_flow.py) diff --git a/algorithms/graph/prims_minimum_spanning.py b/algorithms/graph/prims_minimum_spanning.py new file mode 100644 index 000000000..02295c4c2 --- /dev/null +++ b/algorithms/graph/prims_minimum_spanning.py @@ -0,0 +1,42 @@ +''' +This Prim's Algorithm Code is for finding weight of minimum spanning tree +of a connected graph. +For argument graph, it should be a dictionary type +such as +graph = { + 'a': [ [3, 'b'], [8,'c'] ], + 'b': [ [3, 'a'], [5, 'd'] ], + 'c': [ [8, 'a'], [2, 'd'], [4, 'e'] ], + 'd': [ [5, 'b'], [2, 'c'], [6, 'e'] ], + 'e': [ [4, 'c'], [6, 'd'] ] +} + +where 'a','b','c','d','e' are nodes (these can be 1,2,3,4,5 as well) +''' + + +import heapq # for priority queue + +# prim's algo. to find weight of minimum spanning tree +def prims_minimum_spanning(graph_used): + vis=[] + s=[[0,1]] + prim = [] + mincost=0 + + while(len(s)>0): + v=heapq.heappop(s) + x=v[1] + if(x in vis): + continue + + mincost += v[0] + prim.append(x) + vis.append(x) + + for j in graph_used[x]: + i=j[-1] + if(i not in vis): + heapq.heappush(s,j) + + return mincost diff --git a/tests/test_graph.py b/tests/test_graph.py index 83d35293b..325b5d896 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -10,6 +10,7 @@ from algorithms.graph import bellman_ford from algorithms.graph import bellman_ford from algorithms.graph import count_connected_number_of_component +from algorithms.graph import prims_minimum_spanning import unittest @@ -264,3 +265,23 @@ def test_connected_components_without_edges_graph(self): expected_result = 4 result = count_connected_number_of_component.count_components(l,size) self.assertEqual(result,expected_result) + + +class PrimsMinimumSpanning(unittest.TestCase): + def test_prim_spanning(self): + graph1 = { + 1 : [ [3, 2], [8, 3] ], + 2 : [ [3, 1], [5, 4] ], + 3 : [ [8, 1], [2, 4], [4, 5] ], + 4 : [ [5, 2], [2, 3], [6, 5] ], + 5 : [ [4, 3], [6, 4] ] + } + self.assertEqual(14, prims_minimum_spanning(graph1)) + + graph2 = { + 1 : [ [7, 2], [6, 4] ], + 2 : [ [7, 1], [9, 4], [6, 3] ], + 3 : [ [8, 4], [6, 2] ], + 4 : [ [6, 1], [9, 2], [8, 3] ] + } + self.assertEqual(19, prims_minimum_spanning(graph2))