In [None]:
import heapq
from itertools import count

import matplotlib.pyplot as plt
import networkx as nx


def prim(graph, weight='weight'):
    """Return a minimum spanning tree (or forest) of an undirected weighted graph.

    Uses Prim's algorithm with a binary heap of candidate edges. Each edge is
    pushed at most once per visited endpoint, so the cost is O(E log E). The
    alternative, rescanning and re-sorting every frontier edge each time a
    node is added, is much slower.

    Args:
        graph: An undirected ``networkx.Graph`` whose edges carry a numeric
            weight attribute.
        weight: Name of the edge attribute holding the weight. Defaults to
            ``'weight'``.

    Returns:
        A new ``networkx.Graph`` containing every node of ``graph`` plus the
        selected edges. Each selected edge stores its weight under the
        attribute named by ``weight``.
        - Connected graph: the result is a minimum spanning tree.
        - Disconnected graph: the result is a minimum spanning forest, with
          one tree per connected component.
        - Empty graph: the result is an empty graph.

    Raises:
        KeyError: If an edge examined by the search lacks the ``weight``
            attribute.
    """
    mst = nx.Graph()
    visited = set()
    # A monotonically increasing counter breaks weight ties, so heapq never
    # needs to compare node objects (which may not be mutually orderable).
    tiebreak = count()

    def push_edges(frontier, u):
        # Queue every edge from newly visited node u to a node not yet in the tree.
        for v, data in graph[u].items():
            if v not in visited:
                heapq.heappush(frontier, (data[weight], next(tiebreak), u, v))

    # Trying every node as a root (rather than a single start node) grows one
    # tree per connected component instead of failing on disconnected input.
    for root in graph.nodes:
        if root in visited:
            continue
        visited.add(root)
        mst.add_node(root)
        frontier = []
        push_edges(frontier, root)
        while frontier:
            w, _, u, v = heapq.heappop(frontier)
            if v in visited:
                # Stale entry: v already joined the tree via an edge popped earlier.
                continue
            visited.add(v)
            mst.add_edge(u, v, **{weight: w})
            push_edges(frontier, v)
    return mst


def visualize(graph, *, pos=None, weight='weight'):
    """Draw a graph with node names and edge-weight labels, then show it.

    Args:
        graph: A ``networkx`` graph to draw.
        pos: Optional mapping of node -> (x, y) position. It should cover
            every node of ``graph``. When omitted, a spring layout is
            computed. Passing the mapping returned by an earlier call draws
            a graph over the same nodes (e.g. a graph and its spanning tree)
            in the same positions.
        weight: Edge attribute shown as the edge label. Edges without this
            attribute are drawn without a label.

    Returns:
        The node-position mapping that was used for drawing.
    """
    if pos is None:
        pos = nx.spring_layout(graph)
    # get_edge_attributes skips edges that lack the attribute instead of raising.
    edge_labels = nx.get_edge_attributes(graph, weight)

    nx.draw_networkx_nodes(graph, pos)
    nx.draw_networkx_edges(graph, pos)
    nx.draw_networkx_labels(graph, pos)
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)

    plt.axis('off')
    plt.show()
    return pos

# Demo: a four-node weighted graph, drawn before and after reducing it to
# its minimum spanning tree.
G = nx.Graph()
G.add_weighted_edges_from(
    [
        ('A', 'B', 4),
        ('A', 'C', 1),
        ('B', 'C', 2),
        ('B', 'D', 5),
        ('C', 'D', 3),
    ]
)
visualize(G)  # full input graph
mst = prim(G)
visualize(mst)  # its minimum spanning tree
