# In [3]:
"""
You are given an undirected graph with n nodes labeled 1..n. Some of the nodes are already connected: the i-th
edge connects nodes edges[i][0] and edges[i][1]. Your task is to augment this set of edges with additional
edges so that all the nodes are connected. Find the minimum total cost of the new edges that must be added so
that every node is reachable from every other node.

Input:

n, an int representing the total number of nodes.
edges, a list of integer pairs representing the nodes already connected by an edge.
newEdges, a list where each element is a triplet representing the pair of nodes between which an edge can be 
added and the cost of addition, respectively (e.g. [1, 2, 5] means to add an edge between node 1 and 2, the 
cost would be 5).

Input: n = 6, edges = [[1, 4], [4, 5], [2, 3]], newEdges = [[1, 2, 5], [1, 3, 10], [1, 6, 2], [5, 6, 5]]
Output: 7
Explanation:
There are 3 connected components [1, 4, 5], [2, 3] and [6].
We can connect these components into a single component by connecting node 1 to node 2 and node 1 to node 6 
at a minimum cost of 5 + 2 = 7.

"""
def compute_min_cost(num_nodes, base_mst, poss_mst):
    """Return the minimum total cost of new edges needed to connect the graph.

    Nodes are labeled ``1..num_nodes``. The already-existing edges are merged
    into a union-find structure at no cost; the candidate edges are then
    scanned in increasing cost order (Kruskal's algorithm) and a candidate is
    paid for only when its two endpoints are still in different components.

    Args:
        num_nodes: Total number of nodes in the graph.
        base_mst: Iterable of ``(u, v)`` pairs that are already connected.
        poss_mst: Iterable of ``(u, v, cost)`` triplets describing the edges
            that may be added and what each one costs.

    Returns:
        The minimum total cost, or ``-1`` when the nodes cannot all be
        connected or when any edge refers to a label outside
        ``1..num_nodes``. A graph with at most one node costs ``0``.
    """
    # Every node starts as its own component, including nodes that appear in
    # no edge at all -- otherwise isolated nodes would go unnoticed.
    parent = {label: label for label in range(1, num_nodes + 1)}

    def find(node):
        # Iterative on purpose: a long chain of unions would overflow the
        # interpreter's recursion limit with a recursive lookup.
        root = node
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every node on the walked path at the root.
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    def union(node1, node2):
        # Merge the two components; report whether a merge actually happened.
        root1, root2 = find(node1), find(node2)
        if root1 == root2:
            return False
        parent[root1] = root2
        return True

    existing = list(base_mst)
    # Cheapest candidates first, as Kruskal's algorithm requires.
    candidates = sorted(poss_mst, key=lambda new_edge: new_edge[2])

    # An endpoint outside 1..num_nodes cannot belong to this graph.
    for node1, node2, *_ in existing + candidates:
        if node1 not in parent or node2 not in parent:
            return -1

    components = len(parent)
    for node1, node2 in existing:
        if union(node1, node2):
            components -= 1

    cost_ret = 0
    for node1, node2, cost in candidates:
        if components <= 1:
            break  # already fully connected; nothing more to buy
        # Only pay for an edge that joins two different components.
        if union(node1, node2):
            components -= 1
            cost_ret += cost

    return cost_ret if components <= 1 else -1

def minCost(n, edges, newEdges):
    """Return the minimum total cost of new edges needed to connect all nodes.

    Nodes are labeled ``1..n``. The existing ``edges`` are merged into a
    union-find structure for free, then ``newEdges`` are considered from
    cheapest to most expensive (Kruskal's algorithm); an edge is paid for
    only when it joins two components that are still separate.

    Neither ``edges`` nor ``newEdges`` is modified.

    Args:
        n: Total number of nodes in the graph.
        edges: Iterable of ``[u, v]`` pairs that are already connected.
        newEdges: Iterable of ``[u, v, cost]`` triplets describing the edges
            that may be added and what each one costs.

    Returns:
        The minimum total cost, or the string
        ``"unable to connect all nodes"`` when no selection of new edges
        connects every node.

    Raises:
        IndexError: If an edge refers to a node label outside ``1..n``.
    """
    # parent[i] is the union-find parent of node i; index 0 is unused.
    parent = list(range(n + 1))

    def root_of(label):
        if not 1 <= label <= n:
            raise IndexError(f"node label {label!r} is outside 1..{n}")
        root = label
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short.
        while parent[label] != root:
            parent[label], label = root, parent[label]
        return root

    def join(first, second):
        # Merge the two components; report whether a merge actually happened.
        root_a, root_b = root_of(first), root_of(second)
        if root_a == root_b:
            return False
        parent[root_a] = root_b
        return True

    components = n
    for first, second in edges:
        if join(first, second):
            components -= 1

    total = 0
    # Cheapest candidates first; sorted() copies, so the caller's list is
    # left untouched.
    for first, second, price in sorted(newEdges, key=lambda item: item[2]):
        if join(first, second):
            components -= 1
            total += price

    if components <= 1:
        return total
    return "unable to connect all nodes"

if __name__ == '__main__':
    # Worked example from the problem statement: components [1, 4, 5],
    # [2, 3] and [6] are joined by the new edges 1-2 (cost 5) and 1-6
    # (cost 2), so each solver is expected to print 7.
    n = 6
    edges = [[1, 4], [4, 5], [2, 3]]
    new_edges = [[1, 2, 5], [1, 3, 10], [1, 6, 2], [5, 6, 5]]

    # Run the solvers in this fixed order, one result per line.
    for solver in (compute_min_cost, minCost):
        print(solver(n, edges, new_edges))

# Output:
# 7
# 7
