# Kmeans centroid points

In [None]:
import os
import os.path
import shutil
import logging
import matplotlib
matplotlib.use('Agg')  # Set the backend to Agg
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import cryodrgn
from cryodrgn import analysis, utils, config
from cryodrgn.starfile import Starfile
import pandas as pd
from cryodrgn.source import ImageSource
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist

import pickle

In [2]:
def cluster_kmeans(
    z: np.ndarray, K: int, on_data: bool = True, reorder: bool = True
):
    """
    Cluster z by K means clustering
    Returns cluster labels, cluster centers
    If reorder=True, reorders clusters according to agglomerative clustering of cluster centers
    """
    kmeans = KMeans(n_clusters=K, random_state=0, max_iter=10)
    labels = kmeans.fit_predict(z)
    centers = kmeans.cluster_centers_

    centers_ind = None
    if on_data:
        centers, centers_ind = get_nearest_point(z, centers)

    if reorder:
        g = sns.clustermap(centers)
        reordered = g.dendrogram_row.reordered_ind
        centers = centers[reordered]
        if centers_ind is not None:
            centers_ind = centers_ind[reordered]
        tmp = {k: i for i, k in enumerate(reordered)}
        labels = np.array([tmp[k] for k in labels])
    return labels, centers

def get_nearest_point(
    data: np.ndarray, query: np.ndarray
):
    """
    Find closest point in @data to @query
    Return datapoint, index
    """
    ind = cdist(query, data).argmin(axis=1)
    return data[ind], ind

In [3]:
import pickle
filepath = '/scratch/gpfs/ZHONGE/mj7341/research/00_moml/antibody/dataset/conformational/integrated_poses_chimera.pkl'
with open(filepath, 'rb') as file:
    poses = pickle.load(file)

In [4]:
rots = poses[0].reshape(len(poses[0]),-1)
trans = poses[1]

In [5]:
star_path = '/scratch/gpfs/ZHONGE/mj7341/research/00_moml/antibody/dataset/conformational/add_noise/128_chimera_resample/snr01/snr01_star.star'
ori_s = Starfile.load(star_path)
ori_df = ori_s.df

**UMap**

In [None]:
rots = poses[0].reshape(len(poses[0]),-1)
umap_emb_rot = analysis.run_umap(rots)
umap_emb_trans = analysis.run_umap(poses[1]) # translation (rotation & translation)

**PCA**

In [None]:
from sklearn.decomposition import PCA
def run_pca(z: np.ndarray):
    pca = PCA(z.shape[1])
    pca.fit(z)
    pc = pca.transform(z)
    return pc, pca

In [None]:
rots = poses[0].reshape(len(poses[0]),-1)
pc_rot, pca_rot = run_pca(rots)
pc_trans, pca_trans = run_pca(poses[1])

### 87 ground truth kmeans

In [23]:
ind_dir = '/scratch/gpfs/ZHONGE/mj7341/research/00_moml/antibody/dataset/conformational/cryosparc'

In [1]:
K=10
rots_inds = np.zeros((87,K)).astype(int)
trans_inds = np.zeros((87,K)).astype(int)
for i in range(87):
    idx = np.arange(1000*i,1000*(i+1))
    rots_kmeans_labels, rots_centers=cluster_kmeans(rots[idx],K)
    _, rots_centers_ind = get_nearest_point(rots[idx], rots_centers)
    
    trans_kmeans_labels, trans_centers=cluster_kmeans(trans[idx],K)
    _, trans_centers_ind = get_nearest_point(trans[idx], trans_centers)
    rots_inds[i] = rots_centers_ind
    trans_inds[i] = trans_centers_ind

np.save(f"{ind_dir}/rots_inds_87models_k"+str(K)+".npy", rots_inds)
np.save(f"{ind_dir}/trans_inds_87models_k"+str(K)+".npy", trans_inds)

In [None]:
# figure
fig, axs = plt.subplots(1,2, figsize=(10,5))
for i in range(87):
    idx = np.arange(1000*i,1000*(i+1))    
    axs[0].scatter(pc_rot[rots_inds[i]][:, 0], pc_rot[rots_inds[i]][:, 1], alpha=1, s=10, rasterized=True, c=hexcodes[i])
    axs[1].scatter(pc_trans[trans_inds[i]][:, 0], pc_trans[trans_inds[i]][:, 1], alpha=1, s=10, rasterized=True, c=hexcodes[i])

axs[0].set_title('Rots')
axs[1].set_title('Trans')
for ax in axs.flat:
    ax.set(xlabel='UMAP1', ylabel='UMAP2')

plt.tight_layout()
plt.savefig(f"{ind_dir}/poses_87models_kmeans"+str(K)+"_PCs.png")

### 3D Classification (num classes = 5)

- plot only centroids of kmeans

In [6]:
ind_dir = '/scratch/gpfs/ZHONGE/mj7341/research/00_moml/antibody/dataset/conformational/cryosparc'

In [7]:
# original code: 20231210_fsc_auc
cls_5 = []
num_classes = 5
for i in range(num_classes):
    with open(ind_dir+'/3dcls/3dcls_cs'+str(i)+'.pkl', 'rb') as file:
        cs_idx = pickle.load(file)
    lst = []
    for key, value in cs_idx.items():
        lst.append(np.array(value))
    stacked_array = np.hstack(lst)
    cls_5.append(stacked_array)

In [9]:
K=10 # number of clusters
rots_inds = []
trans_inds = []
for i in range(len(cls_5)):
    rots_kmeans_labels, rots_centers=cluster_kmeans(rots[cls_5[i]],K)
    _, rots_centers_ind = get_nearest_point(rots[cls_5[i]], rots_centers)
    
    trans_kmeans_labels, trans_centers=cluster_kmeans(trans[cls_5[i]],K)
    _, trans_centers_ind = get_nearest_point(trans[cls_5[i]], trans_centers)

    rots_inds.append(rots_centers_ind)
    trans_inds.append(trans_centers_ind)

  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)


**UMap**

In [16]:
labels = np.arange(5)
fig, axs = plt.subplots(1,2, figsize=(10,5))

for i in range(5): # 5: class
    axs[0].scatter(umap_emb_rot[rots_inds[i]][:, 0], umap_emb_rot[rots_inds[i]][:, 1], alpha=1, s=30, rasterized=True, label=labels[i])
    axs[1].scatter(umap_emb_trans[trans_inds[i]][:, 0], umap_emb_trans[trans_inds[i]][:, 1], alpha=1, s=30, rasterized=True, label=labels[i])
    
axs[0].set_title('Rots')
axs[1].set_title('Trans')
for ax in axs.flat:
    ax.set(xlabel='UMAP1', ylabel='UMAP2')
axs[0].legend()
axs[1].legend()

plt.tight_layout()
plt.savefig(f"{ind_dir}/3dcls/poses_kmeans"+str(K)+"_umap.png")

**PCA**

In [21]:
fig, axs = plt.subplots(1,2, figsize=(10,5))
labels = np.arange(K)
for i in range(5):
    axs[0].scatter(pc_rot[rots_inds[i]][:, 0], pc_rot[rots_inds[i]][:, 1], alpha=1, s=30, rasterized=True, label=labels[i])
    axs[1].scatter(pc_trans[trans_inds[i]][:, 0], pc_trans[trans_inds[i]][:, 1], alpha=1, s=30, rasterized=True, label=labels[i])
    
axs[0].set_title('Rots')
axs[1].set_title('Trans')
for ax in axs.flat:
    ax.set(xlabel='UMAP1', ylabel='UMAP2')
axs[0].legend()
axs[1].legend()
plt.tight_layout()
plt.savefig(f"{ind_dir}/3dcls/poses_kmeans"+str(K)+"_PCs.png")