# In [45]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import distance
import networkx as nx

def generate_random_points(center, scale, m):
    """
    Sample m points in 3D from an axis-aligned normal distribution.

    Each coordinate is drawn independently with its own mean (from ``center``)
    and standard deviation (from ``scale``), using NumPy's global RNG.

    Args:
        center (array-like): Per-axis mean, length 3.
        scale (array-like): Per-axis standard deviation, length 3.
        m (int): How many points to draw.

    Returns:
        ndarray: Array of shape (m, 3), one sampled point per row.
    """
    sample_shape = (m, 3)
    return np.random.normal(loc=center, scale=scale, size=sample_shape)

def plot_points(points, title=None, m=None):
    """
    Scatter-plot two sets of points, stored back to back, in 3D space.

    Rows ``points[:m]`` are drawn in blue as "Set 1" and rows ``points[m:]``
    are drawn in green as "Set 2".

    Args:
        points (array-like): Array of shape (N, 3); the first ``m`` rows form
            the first set and the remaining ``N - m`` rows the second set.
        title (str, optional): Title of the plot; no title is set when falsy.
        m (int, optional): Number of rows in the first set. Defaults to
            ``len(points) // 2`` (two equally sized sets). Previously this
            value was read from a module-level global ``m``.
    """
    points = np.asarray(points)
    if m is None:
        # Avoid depending on a global: assume two equally sized sets.
        m = len(points) // 2
    n_second = len(points) - m

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(points[:m, 0], points[:m, 1], points[:m, 2], c='b', label=f'Set 1(n={m})', alpha=1)
    # The second set holds whatever rows remain after the first m.
    ax.scatter(points[m:, 0], points[m:, 1], points[m:, 2], c='g', label=f'Set 2(m={n_second})', alpha=1)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    if title:
        ax.set_title(title)
    ax.legend()
    plt.show()

def plot_minimum_spanning_tree(points, tree, m=None):
    """
    Plot the two point sets and the edges of a spanning tree in 3D space.

    Rows ``points[:m]`` are drawn in blue ("Set 1") and rows ``points[m:]``
    in green ("Set 2"). Each tree edge is drawn as a red line segment.

    Args:
        points (array-like): Array of shape (N, 3) containing the points.
        tree (nx.Graph): Tree whose nodes are integer row indices into
            ``points``; each edge ``(u, v)`` joins ``points[u]`` and
            ``points[v]``.
        m (int, optional): Number of rows in the first set. Defaults to
            ``len(points) // 2`` (two equally sized sets). Previously this
            value was read from a module-level global ``m``.
    """
    points = np.asarray(points)
    if m is None:
        # Avoid depending on a global: assume two equally sized sets.
        m = len(points) // 2
    n_second = len(points) - m

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot points
    ax.scatter(points[:m, 0], points[:m, 1], points[:m, 2], c='b', label=f'Set 1(n={m})', alpha=1)
    ax.scatter(points[m:, 0], points[m:, 1], points[m:, 2], c='g', label=f'Set 2(m={n_second})', alpha=1)

    # Plot minimum spanning tree edges, one segment per edge
    for (u, v) in tree.edges():
        ax.plot([points[u, 0], points[v, 0]],
                [points[u, 1], points[v, 1]],
                [points[u, 2], points[v, 2]], c='r')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('Minimum Spanning Tree')

    # Create a legend
    ax.legend()

    plt.show()

def plot_leaf_nodes1(points, tree, m):
    """
    Highlight the spanning-tree leaves that belong to the first set.

    A leaf is a node of degree 1 in ``tree``. Only leaves with index
    ``< m`` (rows of the first set) are highlighted in red, on top of both
    point sets.

    Args:
        points (array-like): Array of shape (N, 3); the first ``m`` rows form
            the first set and the remaining rows the second set.
        tree (nx.Graph): Tree whose nodes are integer row indices into
            ``points``.
        m (int): Number of rows in the first set.
    """
    points = np.asarray(points)
    n_second = len(points) - m

    # Leaf nodes of the tree that belong to the first set (indices 0 .. m-1)
    leaf_nodes = [node for node in tree.nodes() if tree.degree(node) == 1 and node < m]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot points
    ax.scatter(points[:m, 0], points[:m, 1], points[:m, 2], c='b', label=f'Set 1(n={m})', alpha=1)
    ax.scatter(points[m:, 0], points[m:, 1], points[m:, 2], c='g', label=f'Set 2(m={n_second})', alpha=1)

    # Plot leaf nodes over the base scatter
    ax.scatter(points[leaf_nodes, 0], points[leaf_nodes, 1], points[leaf_nodes, 2], c='r', label=f'Leaf Nodes(x1={len(leaf_nodes)})', alpha=1)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('Leaf Nodes of Minimum Spanning Tree')

    # Create a legend
    ax.legend()

    plt.show()
    
def plot_leaf_nodes2(points, tree, m):
    """
    Highlight the spanning-tree leaves that belong to the second set.

    A leaf is a node of degree 1 in ``tree``. Only leaves with index
    ``>= m`` (rows of the second set) are highlighted in red, on top of both
    point sets.

    Args:
        points (array-like): Array of shape (N, 3); the first ``m`` rows form
            the first set and the remaining rows the second set.
        tree (nx.Graph): Tree whose nodes are integer row indices into
            ``points``.
        m (int): Number of rows in the first set; the second set starts at
            row index ``m``.
    """
    points = np.asarray(points)
    n_second = len(points) - m

    # Leaf nodes of the tree that belong to the second set. Index m is the
    # first row of set 2, so the bound must be inclusive (>=), not > .
    leaf_nodes = [node for node in tree.nodes() if tree.degree(node) == 1 and node >= m]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot points
    ax.scatter(points[:m, 0], points[:m, 1], points[:m, 2], c='b', label=f'Set 1(n={m})', alpha=1)
    ax.scatter(points[m:, 0], points[m:, 1], points[m:, 2], c='g', label=f'Set 2(m={n_second})', alpha=1)

    # Plot leaf nodes over the base scatter
    ax.scatter(points[leaf_nodes, 0], points[leaf_nodes, 1], points[leaf_nodes, 2], c='r', label=f'Leaf Nodes(y1={len(leaf_nodes)})', alpha=1)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('Leaf Nodes of Minimum Spanning Tree')

    # Create a legend
    ax.legend()

    plt.show()
# Location (mean) and scale (std-dev) of the two Gaussian clouds
centers = [[5, 5, 5], [-5, -5, -5]]
scales = [[2, 2, 2], [3, 3, 3]]

# Draw m points per cloud; set 1 is sampled before set 2
m = 50
random_points1 = generate_random_points(centers[0], scales[0], m)
random_points2 = generate_random_points(centers[1], scales[1], m)

# Stack the clouds: rows [0, m) are set 1, rows [m, 2m) are set 2
all_points = np.concatenate([random_points1, random_points2], axis=0)

# Full pairwise Euclidean distance matrix
distances = distance.cdist(all_points, all_points)

# Complete weighted graph: one edge per unordered pair i < j, in row-major order
n = len(all_points)
graph = nx.Graph()
graph.add_weighted_edges_from(
    (i, j, distances[i, j]) for i in range(n) for j in range(i + 1, n)
)

# Minimum spanning tree (networkx uses Kruskal's algorithm by default)
mst = nx.minimum_spanning_tree(graph)

# Visualise the raw points, the tree, and the leaves of each set
plot_points(all_points)
plot_minimum_spanning_tree(all_points, mst)
plot_leaf_nodes1(all_points, mst, m)
plot_leaf_nodes2(all_points, mst, m)