# K-Means Clustering & UMAP Visualization

In [1]:
import sys
import gzip
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
import umap
import matplotlib.pyplot as plt
import seaborn as sns

src_dir = Path.cwd().parent

# sys.path strictly for importing modules
sys.path.append(str(src_dir))
from utils.data_utils import *

DATA_PATH = src_dir / "data" / "processed"
COHORT_PATH = src_dir / "data" / "processed" / "diabetic_patient_day_table.csv.gz"

In [2]:
all_embeddings = np.load(DATA_PATH / "patient_embeddings.npy")
patient_ids = np.load(DATA_PATH / "patient_ids.npy")

emb_df = pd.DataFrame(all_embeddings, index=patient_ids)
emb_df.index.name = "patient_id"

emb_df.head()

Unnamed: 0_level_0,0
patient_id,Unnamed: 1_level_1
16638841.0,0.365061
13006599.0,0.365025
15057166.0,0.36483
18529984.0,0.364873
19387056.0,0.364791


In [3]:
cohort = load_data(COHORT_PATH)
print(cohort.shape)
cohort.head()

(1372192, 65)


Unnamed: 0,subject_id,chartdate,50803,50809,50822,50824,50837,50841,50842,50847,...,n_admissions,first_admission_date,last_admission_date,hypertension_flag,ckd_flag,obesity_flag,neuropathy_flag,retinopathy_flag,heart_disease_flag,insulin_flag
0,10000635,2136-04-08,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
1,10000635,2138-09-29,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
2,10000635,2141-08-15,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
3,10000635,2142-12-23,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
4,10000635,2143-06-06,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False


## K-Means Clustering (Mini-Batch)

In [4]:
n_clusters = 5
mbk = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10000, random_state=42)

# Fit and predict
emb_df['cluster'] = mbk.fit_predict(all_embeddings)

print("Cluster counts:")
print(emb_df['cluster'].value_counts())

Cluster counts:
cluster
0    31738
3    10145
2     3124
4      941
Name: count, dtype: int64


In [None]:
reducer = umap.UMAP(
    n_neighbors=30,
    min_dist=0.1, 
    n_components=2, 
    random_state=42,
    verbose=True
)
umap_embeddings = reducer.fit_transform(all_embeddings)

umap_df = pd.DataFrame(
    umap_embeddings, 
    columns=['UMAP1', 'UMAP2'], 
    index=patient_ids
)
umap_df['cluster'] = emb_df['cluster']

plt.figure(figsize=(12, 10))
sns.scatterplot(
    data=umap_df,
    x='UMAP1',
    y='UMAP2',
    hue='cluster',
    palette='tab10',
    s=40,
    alpha=0.7
)
plt.title("Full-cohort Patient Embeddings UMAP (colored by cluster)")
plt.legend(title='Cluster')
plt.show()

UMAP(n_jobs=1, n_neighbors=30, random_state=42, verbose=True)
Thu Dec 11 22:30:35 2025 Construct fuzzy simplicial set
Thu Dec 11 22:30:35 2025 Finding Nearest Neighbors
Thu Dec 11 22:30:35 2025 Building RP forest with 16 trees


  warn(


Thu Dec 11 22:30:38 2025 NN descent for 15 iterations
	 1  /  15
	 2  /  15
	 3  /  15
	Stopping threshold met -- exiting after 3 iterations
Thu Dec 11 22:30:45 2025 Finished Nearest Neighbor Search
Thu Dec 11 22:30:46 2025 Construct embedding


In [20]:
print(umap_df.columns)
umap_df.head()

Index(['UMAP1', 'UMAP2', 'cluster'], dtype='object')


Unnamed: 0,UMAP1,UMAP2,cluster
16638841.0,-3.356588,19.633076,0
13006599.0,-1.331632,-7.334458,0
15057166.0,12.677441,-9.492103,2
18529984.0,13.321322,6.057528,3
19387056.0,16.099903,-4.441524,3


In [26]:
cohort['subject_id'] = cohort['subject_id'].astype(umap_df.index.dtype)
clustered_df = cohort.merge(
    umap_df[['cluster']],
    left_on='subject_id',
    right_index=True,
)
print(clustered_df.columns)
clustered_df.head()

Unnamed: 0,subject_id,chartdate,50803,50809,50822,50824,50837,50841,50842,50847,...,first_admission_date,last_admission_date,hypertension_flag,ckd_flag,obesity_flag,neuropathy_flag,retinopathy_flag,heart_disease_flag,insulin_flag,cluster
0,10000635.0,2136-04-08,,,,,,,,,...,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False,3
1,10000635.0,2138-09-29,,,,,,,,,...,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False,3
2,10000635.0,2141-08-15,,,,,,,,,...,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False,3
3,10000635.0,2142-12-23,,,,,,,,,...,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False,3
4,10000635.0,2143-06-06,,,,,,,,,...,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False,3


In [None]:
sns.lineplot(
    data=clustered_df,
    x='chartdate',
    y='age',
    hue='cluster',
    estimator='mean'
)

KeyboardInterrupt: 