In [25]:
import struct
import numpy as np
import umap
import plotly.graph_objects as go
import plotly.express as px

In [26]:
def read_all_iterations(filename):
    """Load every k-means iteration stored in a binary dump file.

    File layout (native byte order, every field is 4 bytes wide):

    * header: ``n_clusters``, ``n_points``, ``n_dim`` as unsigned 32-bit ints;
    * then, for each iteration, ``n_clusters`` centroid records followed by
      ``n_points`` point records. A record is ``n_dim`` float32 values
      followed by one int32 (the centroid id, or the point's cluster id).

    The number of iterations is inferred from the payload size; trailing
    bytes that do not fill a whole iteration are ignored.

    Parameters
    ----------
    filename : str or path-like
        Path of the binary file to read.

    Returns
    -------
    tuple
        ``(n_clusters, n_points, n_dim, iterations)`` where ``iterations`` is
        a list with one dict per iteration holding:
        ``'points'`` (float32, shape ``(n_points, n_dim)``),
        ``'point_cluster_ids'`` (int32, shape ``(n_points,)``),
        ``'centroids'`` (float32, shape ``(n_clusters, n_dim)``),
        ``'centroid_ids'`` (int32, shape ``(n_clusters,)``),
        ``'n_points'`` and ``'n_centroids'``.

    Raises
    ------
    struct.error
        If the file is shorter than the 12-byte header.
    """
    with open(filename, "rb") as f:
        # Header is read once; everything after it is iteration payload.
        n_clusters, n_points, n_dim = struct.unpack("III", f.read(12))
        file_data = f.read()

    # Every record (centroid or point) is n_dim float32 values + one int32,
    # i.e. (n_dim + 1) four-byte words.
    words_per_record = n_dim + 1
    records_per_iteration = n_clusters + n_points
    bytes_per_iteration = 4 * words_per_record * records_per_iteration

    # A header describing zero records has zero-sized iterations; report no
    # iterations instead of dividing by zero.
    n_iterations = len(file_data) // bytes_per_iteration if bytes_per_iteration else 0

    print(f"📦 Detected {n_iterations} iterations")

    iterations = []
    if n_iterations == 0:
        return n_clusters, n_points, n_dim, iterations

    # View the payload as a (iteration, record, word) grid of raw 32-bit words
    # and reinterpret the same memory as float32 / int32 where needed.
    n_words = n_iterations * records_per_iteration * words_per_record
    words = np.frombuffer(file_data, dtype=np.uint32, count=n_words).reshape(
        n_iterations, records_per_iteration, words_per_record
    )
    as_float = words.view(np.float32)
    as_int = words.view(np.int32)

    for k in range(n_iterations):
        # .copy() detaches each array from the read-only file buffer so the
        # results are writable, contiguous and independent of each other.
        iterations.append({
            'points': as_float[k, n_clusters:, :n_dim].copy(),
            'point_cluster_ids': as_int[k, n_clusters:, n_dim].copy(),
            'centroids': as_float[k, :n_clusters, :n_dim].copy(),
            'centroid_ids': as_int[k, :n_clusters, n_dim].copy(),
            'n_points': n_points,
            'n_centroids': n_clusters
        })

    return n_clusters, n_points, n_dim, iterations

In [27]:
# Path of the binary dump to load (relative path).
filename = "./build/data/cluster_data"
(n_clusters, n_points, n_dim, iterations) = read_all_iterations(filename=filename)

📦 Detected 10 iterations


In [28]:
# Quick sanity check of what was parsed for the first iteration.
first_iteration = iterations[0]
first_centroid_ids = np.unique(first_iteration['centroid_ids'])

print("✅ First iteration summary:")
print("Centroids:\n", first_iteration['centroids'])
print("Points shape:", first_iteration['points'].shape)
print("Unique cluster IDs:", first_centroid_ids)


✅ First iteration summary:
Centroids:
 [[ 186.8906   ]
 [  87.09044  ]
 [  65.93164  ]
 [ -26.194336 ]
 [  -1.4538269]
 [-154.16461  ]
 [  13.540939 ]
 [  46.56109  ]
 [  57.317913 ]
 [  31.976303 ]]
Points shape: (70, 1)
Unique cluster IDs: [0 1 2 3 4 5 6 7 8 9]


In [29]:
# Collect the per-iteration arrays so that all iterations can be embedded in
# one shared 3D space, and remember how many rows each iteration contributed.
all_points = []
all_centroids = []
all_point_ids = []
iteration_meta = []

for it in iterations:
    all_points.append(it['points'])
    all_centroids.append(it['centroids'])
    all_point_ids.append(it['point_cluster_ids'])

    iteration_meta.append({
        'n_points': len(it['points']),
        'n_centroids': len(it['centroids']),
        'point_cluster_ids': it['point_cluster_ids'],
        'centroid_ids': it['centroid_ids'],
    })

# Stack all points for reduction
points_concat = np.vstack(all_points)
centroids_concat = np.vstack(all_centroids)

# Fixed seed so the embedding is identical on every Restart & Run All.
# Note: UMAP runs single-threaded when a random_state is given.
UMAP_RANDOM_STATE = 42

n_features = points_concat.shape[1]
if n_features > 3:
    # More than 3 dimensions: learn a 3D embedding on the points and project
    # the centroids with the same fitted reducer.
    reducer = umap.UMAP(
        n_components=3,
        n_neighbors=50,
        min_dist=1,
        metric='euclidean',
        low_memory=True,
        random_state=UMAP_RANDOM_STATE,
    )
    points_umap = reducer.fit_transform(points_concat)
    centroids_umap = reducer.transform(centroids_concat)
elif n_features < 3:
    # Fewer than 3 dimensions: pad the missing coordinates with zeros.
    padding = ((0, 0), (0, 3 - n_features))
    points_umap = np.pad(points_concat, padding, mode='constant', constant_values=0)
    centroids_umap = np.pad(centroids_concat, padding, mode='constant', constant_values=0)
else:
    # Already 3D: use the coordinates as they are.
    points_umap = points_concat
    centroids_umap = centroids_concat


In [30]:
# Cut the stacked 3D coordinates back into one slice per iteration, using the
# row counts recorded in iteration_meta.
points_3d_per_iter, centroids_3d_per_iter = [], []

point_start = centroid_start = 0
for meta in iteration_meta:
    point_stop = point_start + meta['n_points']
    centroid_stop = centroid_start + meta['n_centroids']

    points_3d_per_iter.append(points_umap[point_start:point_stop])
    centroids_3d_per_iter.append(centroids_umap[centroid_start:centroid_stop])

    point_start, centroid_start = point_stop, centroid_stop


In [31]:
def voxel_downsample(points, cluster_ids, voxel_size=0.1, max_points=10_000, rng=None):
    """Thin out a point cloud per cluster by keeping one point per voxel.

    For every distinct cluster id the points are binned into cubic voxels of
    edge ``voxel_size`` and the first point found in each occupied voxel is
    kept. If a cluster still has more than ``max_points`` points, a random
    subset of ``max_points`` is drawn without replacement.

    Parameters
    ----------
    points : array-like, shape (n, d)
        Point coordinates.
    cluster_ids : array-like, shape (n,)
        Cluster id of each point.
    voxel_size : float
        Voxel edge length; must be positive.
    max_points : int or float
        Per-cluster cap on the number of returned points (truncated to int).
    rng : numpy.random.Generator, optional
        Random generator used for the subsampling. When omitted, the global
        ``np.random`` state is used, as before.

    Returns
    -------
    (ndarray, ndarray)
        The kept points, grouped by ascending cluster id, and an int32 array
        with the cluster id of each kept point. Both are empty (shapes
        ``(0, d)`` and ``(0,)``) when there is nothing to keep.

    Raises
    ------
    ValueError
        If ``voxel_size`` is not positive.
    """
    if voxel_size <= 0:
        raise ValueError("voxel_size must be positive")

    points = np.asarray(points)
    cluster_ids = np.asarray(cluster_ids)
    max_points = int(max_points)
    choose = np.random.choice if rng is None else rng.choice

    down_points, down_ids = [], []
    for cid in np.unique(cluster_ids):
        pts = points[cluster_ids == cid]
        if len(pts) == 0:
            # Only reachable for ids that never compare equal to themselves (NaN).
            continue
        # Integer voxel coordinates; np.unique keeps the first point per voxel.
        vox = np.floor(pts / voxel_size).astype(int)
        _, idx = np.unique(vox, axis=0, return_index=True)
        reduced = pts[idx]
        if len(reduced) > max_points:
            reduced = reduced[choose(len(reduced), max_points, replace=False)]
        down_points.append(reduced)
        down_ids.append(np.full(len(reduced), cid, dtype=np.int32))

    if not down_points:
        # np.vstack / np.concatenate reject empty lists, so build the empty
        # result explicitly, preserving the trailing point dimensions.
        empty_points = np.empty((0,) + points.shape[1:], dtype=points.dtype)
        return empty_points, np.empty(0, dtype=np.int32)

    return np.vstack(down_points), np.concatenate(down_ids)

# Iterations with more points than this are thinned out before plotting; the
# budget is split across clusters when calling voxel_downsample.
max_points_per_iter = 3500

reduced_points_3d_per_iter = []
reduced_cluster_ids_per_iter = []

for pts3d, meta in zip(points_3d_per_iter, iteration_meta):
    # Default: keep the iteration untouched.
    kept_pts, kept_ids = pts3d, meta['point_cluster_ids']
    if len(pts3d) > max_points_per_iter:
        kept_pts, kept_ids = voxel_downsample(
            pts3d,
            meta['point_cluster_ids'],
            max_points=max_points_per_iter/n_clusters,
        )
    reduced_points_3d_per_iter.append(kept_pts)
    reduced_cluster_ids_per_iter.append(kept_ids)

# From here on the plotting code sees the reduced points and matching ids.
points_3d_per_iter = reduced_points_3d_per_iter
for i, meta in enumerate(iteration_meta):
    meta['point_cluster_ids'] = reduced_cluster_ids_per_iter[i]


In [32]:
# Bernat: keep this at 1 -- any setting <1 will cause sorting order artifacts
# because of the rendering engine's static sorting order on transparents.
point_opacity = 1

colors = px.colors.qualitative.Plotly


def _cluster_traces(k, label_points=False):
    """Build the Scatter3d traces of iteration ``k``.

    For each centroid id (ascending) two traces are produced, in this order:
    the points assigned to that cluster, then the centroid marker itself.
    ``label_points`` additionally names the point traces ('C<id>').
    """
    meta = iteration_meta[k]
    pts = points_3d_per_iter[k]
    ctr = centroids_3d_per_iter[k]

    traces = []
    for i, cid in enumerate(np.unique(meta['centroid_ids'])):
        colour = colors[i % len(colors)]  # cycle through the palette
        in_cluster = meta['point_cluster_ids'] == cid
        is_centroid = meta['centroid_ids'] == cid
        point_name = {'name': 'C' + str(cid)} if label_points else {}

        traces.append(go.Scatter3d(
            x=pts[in_cluster, 0],
            y=pts[in_cluster, 1],
            z=pts[in_cluster, 2],
            mode='markers',
            marker=dict(size=4, color=colour, opacity=point_opacity),
            showlegend=False,
            **point_name
        ))
        traces.append(go.Scatter3d(
            x=ctr[is_centroid, 0],
            y=ctr[is_centroid, 1],
            z=ctr[is_centroid, 2],
            mode='markers+text',
            marker=dict(size=6, symbol='diamond', color=colour, line=dict(width=0.5, color='black')),
            name=f'Cluster {cid}',
            text=[f"C{cid}"] * np.sum(is_centroid),
            textposition="top center"
        ))
    return traces


# Static layout: a Play button plus one slider step per iteration.
play_button = dict(label='Play', method='animate', args=[None])
slider_steps = [
    dict(method='animate', args=[[f'frame{k}']], label=str(k))
    for k in range(len(iterations))
]

layout = go.Layout(
    title="🌌 UMAP 3D K-Means Iterations",
    margin=dict(l=0, r=0, b=0, t=40),
    scene=dict(
        xaxis_title='UMAP-1',
        yaxis_title='UMAP-2',
        zaxis_title='UMAP-3',
        bgcolor='white'
    ),
    updatemenus=[dict(type='buttons', showactive=False, buttons=[play_button])],
    sliders=[dict(
        steps=slider_steps,
        transition=dict(duration=0),
        x=0.156, y=0, len=0.6
    )]
)

# The figure starts out showing iteration 0 (with named point traces).
fig = go.Figure(layout=layout)
for trace in _cluster_traces(0, label_points=True):
    fig.add_trace(trace)

# One animation frame per iteration, named to match the slider steps.
frames = [
    go.Frame(data=_cluster_traces(k), name=f"frame{k}")
    for k in range(len(iteration_meta))
]

fig.frames = frames
fig.show()
