In [2]:
from sklearn.metrics import pairwise_distances
from hdbscan import HDBSCAN
import numpy as np
import pandas as pd
from plotly import express as px
import os, sys
sys.path.append("..")

from plot import create_hover_plot
import numpy as np
import pandas as pd
from IPython.display import display, HTML


In [4]:
# Every intermediate artefact read here lives in the same directory.
INTERMEDIATE_DIR = "../data/02_intermediate"

chunks_df = pd.read_parquet(f"{INTERMEDIATE_DIR}/chunks.parquet")
umap_projection = pd.read_parquet(f"{INTERMEDIATE_DIR}/umap_projection.parquet")
pca_projection = pd.read_parquet(f"{INTERMEDIATE_DIR}/pca_projection.parquet")

# Display the (rows, dims) of the PCA projection that clustering runs on.
pca_projection.shape

(1420, 428)

In [80]:
# Pairwise cosine distances between PCA rows; HDBSCAN consumes them as a
# precomputed metric. min_cluster_size=2 permits clusters of just two points.
dist_matrix = pairwise_distances(pca_projection, metric="cosine")
hdbscan_model = HDBSCAN(
    metric="precomputed",
    min_cluster_size=2,
    min_samples=1,
)
clusterer = hdbscan_model.fit(dist_matrix)

In [81]:
import hdbscan

def get_exemplars(cluster_id, condensed_tree):
    """Return the exemplar points of one selected HDBSCAN cluster.

    Walks down the condensed tree from ``cluster_id`` to its leaf clusters and,
    for each leaf, keeps the points that leave it at that leaf's largest lambda
    value (the longest-persisting points, i.e. the "heart" of the leaf).

    Parameters
    ----------
    cluster_id : int
        Id of a cluster node in the condensed tree, e.g. an entry returned by
        ``condensed_tree._select_clusters()``. These ids live in the tree's
        ``parent``/``child`` id space; they are not ``clusterer.labels_`` values.
    condensed_tree : hdbscan condensed tree
        Typically ``clusterer.condensed_tree_``. Only its private ``_raw_tree``
        record array is read, via the fields ``parent``, ``child``,
        ``lambda_val`` and ``child_size``.

    Returns
    -------
    numpy.ndarray of int
        ``child`` ids of the exemplar points, grouped leaf by leaf. The rest of
        this notebook uses them as positional row indices into the projections.

    Raises
    ------
    ValueError
        If a leaf has no rows in the raw tree (``max()`` of an empty array).
    """
    raw_tree = condensed_tree._raw_tree
    # Keep only cluster-to-cluster edges; singleton points have child_size == 1.
    cluster_tree = raw_tree[raw_tree['child_size'] > 1]
    # NOTE: private hdbscan helper -- may change between hdbscan versions.
    leaves = hdbscan.plots._recurse_leaf_dfs(cluster_tree, cluster_id)

    exemplar_chunks = []
    for leaf in leaves:
        # Compute the leaf's rows once instead of re-masking for each field.
        leaf_rows = raw_tree[raw_tree['parent'] == leaf]
        max_lambda = leaf_rows['lambda_val'].max()
        exemplar_chunks.append(leaf_rows['child'][leaf_rows['lambda_val'] == max_lambda])

    if not exemplar_chunks:
        return np.array([], dtype=int)
    # Concatenate once rather than re-allocating with np.hstack every iteration.
    return np.concatenate(exemplar_chunks).astype(int)

# Exemplar point indices for every cluster HDBSCAN selected, keyed by the
# condensed-tree node id that _select_clusters() returns for that cluster.
tree = clusterer.condensed_tree_
cluster_ids = tree._select_clusters()
exemplars = {node_id: get_exemplars(node_id, tree) for node_id in cluster_ids}


In [68]:
# Project helper (plot.py): UMAP coordinates + chunk table + one HDBSCAN
# label per point; the returned object is displayed as the cell output.
point_labels = clusterer.labels_
create_hover_plot(umap_projection, chunks_df, point_labels)

In [82]:
# Row positions of all points, ordered by ascending HDBSCAN label.
sort_idx = clusterer.labels_.argsort()

In [83]:
from plotly import graph_objects as go

In [84]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

# Qualitative palette; cluster labels wrap modulo its length, so colours
# repeat once there are more clusters than palette entries.
colormap = px.colors.qualitative.Light24

# Background layer: every chunk as a small grey point (hover shows its row index).
fig = px.scatter(
    umap_projection.iloc[sort_idx].reset_index(),
    x=0,
    y=1,
    labels={'color': 'Cluster'},
    title='UMAP projection of chunks colored by cluster',
    hover_data=["index"]
).update_traces(marker=dict(color='lightgray', size=5)).update_layout(
    width=800,
    height=800
)

# Overlay each cluster's exemplars in that cluster's colour.
for cluster_id, exemplar_points in exemplars.items():
    # `cluster_id` is a condensed-tree node id, not a clusterer.labels_ value.
    # Read the label off the exemplars themselves so legend names and colours
    # agree with the labels shown by create_hover_plot.
    if len(exemplar_points):
        cluster_label = int(clusterer.labels_[exemplar_points[0]])
    else:
        cluster_label = int(cluster_id)
    cluster_color = colormap[cluster_label % len(colormap)]

    fig.add_trace(
        go.Scatter(
            x=umap_projection.iloc[exemplar_points, 0],
            y=umap_projection.iloc[exemplar_points, 1],
            mode='markers',
            marker=dict(
                size=10,
                color=cluster_color,
                line=dict(width=1, color='Black')
            ),
            name=f'Cluster {cluster_label}',
            legendgroup=f'Cluster {cluster_label}',
            showlegend=True
        )
    )

fig.show()

In [85]:
# Positions of every exemplar (row indices into pca_projection/umap_projection),
# sorted ascending so the frame comes out in the same order sort_index() gave.
if exemplars:
    exemplar_index = np.sort(
        np.concatenate([np.asarray(points, dtype=int) for points in exemplars.values()])
    )
else:
    exemplar_index = np.array([], dtype=int)

# Build the frame in one shot instead of re-concatenating inside a loop.
# `.to_numpy()` drops pca_projection's column labels, as before.
exemplars_pca_projection = pd.DataFrame(
    pca_projection.iloc[exemplar_index].to_numpy(), index=exemplar_index
)

# Cluster label of each row above, aligned with the sorted rows. Taken from
# clusterer.labels_ so it matches the labels plotted elsewhere; enumerating the
# `exemplars` dict would number clusters by dict order, which need not match.
clusters = clusterer.labels_[exemplar_index].tolist()

In [64]:
# Persist exemplar PCA rows alongside the other intermediate artefacts
# (the same ../data/02_intermediate directory the inputs are read from).
exemplars_pca_projection.to_parquet("../data/02_intermediate/exemplars_pca_projection.parquet")

In [86]:
# Exemplars only, in UMAP space, coloured by each point's HDBSCAN label.
exemplar_rows = exemplars_pca_projection.index.to_numpy()
fig = px.scatter(
    umap_projection.iloc[exemplar_rows].reset_index(),
    x=0,
    y=1,
    # Colour by cluster label -- the colour axis is titled "Cluster"; the row
    # index used previously carries no cluster information.
    color=clusterer.labels_[exemplar_rows],
    color_continuous_scale="Turbo",
    labels={'color': 'Cluster'},
    title='UMAP projection of cluster exemplars colored by cluster',
    hover_data=["index"]
).update_traces(marker=dict(size=15)).update_layout(
    width=800,
    height=800
)
fig.show()