In [1]:
import struct
import numpy as np
import umap
import plotly.graph_objects as go
import plotly as px

In [2]:
def _read_records(stream, n_records, n_dim, what):
    """Read ``n_records`` fixed-size records and split them into vectors and ids.

    Each record is ``n_dim`` float32 values followed by one int32 id, i.e.
    ``n_dim + 1`` four-byte words. The whole section is read in one call and
    its length is checked, so a truncated file fails with a clear message.
    """
    words_per_record = n_dim + 1
    expected_bytes = 4 * words_per_record * n_records
    raw = stream.read(expected_bytes)
    if len(raw) != expected_bytes:
        raise ValueError(
            f"truncated file: expected {expected_bytes} bytes of {what}, got {len(raw)}"
        )

    # View the section as a table of 4-byte words: the first n_dim columns
    # hold the float32 coordinates, the last column holds the int32 id.
    words = np.frombuffer(raw, dtype=np.int32).reshape(n_records, words_per_record)
    # .copy() detaches from the read-only bytes buffer so the results are
    # writable and contiguous; .view() reinterprets the bits as float32.
    vectors = words[:, :n_dim].copy().view(np.float32)
    ids = words[:, n_dim].copy()
    return vectors, ids


def read_clustered_data(filename):
    """Load centroids and labelled points from a binary cluster dump.

    File layout as read here (native byte order, 4-byte fields):

    * header: three unsigned ints ``n_clusters``, ``n_points``, ``n_dim``;
    * ``n_clusters`` centroid records: ``n_dim`` float32 values, then one int32 id;
    * ``n_points`` point records: ``n_dim`` float32 values, then one int32 cluster id.

    Parameters
    ----------
    filename : str or path-like
        Path of the binary file to read.

    Returns
    -------
    tuple
        ``(point_data, point_cluster_ids, centroid_data, centroid_ids, n_dim)``
        where the data arrays are float32 with shape ``(n, n_dim)`` and the id
        arrays are int32 with shape ``(n,)``.

    Raises
    ------
    ValueError
        If the file ends before the header or a full section has been read.
    """
    with open(filename, "rb") as f:
        header = f.read(12)
        if len(header) != 12:
            raise ValueError(
                f"truncated file: expected 12 bytes of header, got {len(header)}"
            )
        n_clusters, n_points, n_dim = struct.unpack("III", header)

        # --- Read centroids ---
        centroid_data, centroid_ids = _read_records(f, n_clusters, n_dim, "centroids")

        # --- Read points ---
        point_data, point_cluster_ids = _read_records(f, n_points, n_dim, "points")

        # Summary of what was loaded (kept so the cell output stays the same).
        print("n_clusters:", n_clusters)
        print("n_points:", n_points)
        print("n_dim:", n_dim)
        print("point_data shape:", point_data.shape)
        print("centroid_data shape:", centroid_data.shape)

    return point_data, point_cluster_ids, centroid_data, centroid_ids, n_dim

In [3]:
# Load the binary cluster dump; the names bound here are used by the cells below.
filename = "clustered_data"
(
    point_data,
    point_cluster_ids,
    centroids,
    centroid_ids,
    n_dim,
) = read_clustered_data(filename)

n_clusters: 4
n_points: 200
n_dim: 3
point_data shape: (200, 3)
centroid_data shape: (4, 3)


In [4]:
# UMAP is stochastic: pin the seed so the layout is identical on every
# Restart & Run All instead of changing from run to run.
UMAP_RANDOM_STATE = 42

reducer = umap.UMAP(n_components=3, random_state=UMAP_RANDOM_STATE)
# Fit the embedding on the points only, then project the centroids into the
# same fitted space so both share one coordinate system.
points_3d = reducer.fit_transform(point_data)
centroids_3d = reducer.transform(centroids)



In [5]:
# Build an interactive 3D scatter: one faint point trace and one labelled
# centroid trace per cluster, both drawn in the cluster's colour.
fig = go.Figure()

# Clusters are enumerated from the centroid ids, so only points whose label
# matches one of those ids are drawn.
unique_clusters = centroid_ids.copy()
# NOTE: `px` is the top-level `plotly` package in this notebook
# (`import plotly as px`), not `plotly.express`.
colors = px.colors.qualitative.Plotly
# Cycle through the palette if there are more clusters than colours.
cluster_color_map = {cluster_id: colors[i % len(colors)] for i, cluster_id in enumerate(unique_clusters)}

for cluster_id in unique_clusters:
    # Add points (hidden from the legend; the centroid trace carries the entry)
    mask = point_cluster_ids == cluster_id
    fig.add_trace(go.Scatter3d(
        x=points_3d[mask, 0],
        y=points_3d[mask, 1],
        z=points_3d[mask, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=cluster_color_map[cluster_id],
            opacity=0.5
        ),
        name=f'p - C{cluster_id}',
        showlegend=False
    ))


    # Add centroids (larger diamond with a text label above it)
    mask = centroid_ids == cluster_id
    fig.add_trace(go.Scatter3d(
        x=centroids_3d[mask, 0],
        y=centroids_3d[mask, 1],
        z=centroids_3d[mask, 2],
        mode='markers+text',
        marker=dict(
            size=6,
            symbol='diamond',
            color=cluster_color_map[cluster_id],
            line=dict(
                color=cluster_color_map[cluster_id],
                width=5
            )
        ),
        name=f'Cluster {cluster_id}',
        text=f"C{cluster_id}",
        textposition="top center"
    ))

fig.update_layout(
    # The title previously held a mis-decoded emoji ("ðŸŒŒ"); restored to the
    # intended character.
    title="🌌 UMAP 3D Projection of Clustered Data",
    margin=dict(l=0, r=0, b=0, t=40),
    scene=dict(
        xaxis_title='UMAP-1',
        yaxis_title='UMAP-2',
        zaxis_title='UMAP-3',
        bgcolor='white'
    ),
    legend=dict(x=0.02, y=0.98),
    showlegend=True
)

fig.show()