# Coordinate Descent to Cluster Cells

We want to perform coordinate descent to cluster cells. We will initialize a clustering using a common method, such as spectral clustering. Then, we will use that clustering to fit maximum likelihood embryo centers. Next, we will set the cluster of each cell according to maximum likelihood. This will be repeated until convergence.

In [1]:
import pandas as pd
import anndata
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
from functools import partial
from multiprocessing import Pool
from sklearn.metrics import adjusted_rand_score

## Load Data

In [2]:
NUM_EMBRYOS = 3
UMI_CUTOFF = 2000

CALLS_SAVE_PATH = 'data/calls_2022_07_06.h5ad'
READS_SAVE_PATH = 'data/reads_2022_07_06.h5ad'

In [3]:
variants_joined = anndata.read_h5ad(CALLS_SAVE_PATH)
reads_joined = anndata.read_h5ad(READS_SAVE_PATH)

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


In [4]:
umi_sum = reads_joined.X.sum(axis=1)
cells_to_keep = umi_sum >= UMI_CUTOFF

variants_joined = variants_joined[cells_to_keep, :].copy()
reads_joined = reads_joined[cells_to_keep, :].copy()

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


## Run Descent

In [5]:
# Computes the log probability of each value of
# variant j in cluster i
def compute_best_log_prob_pi(j, i=None):
    # Get the calls for the variant for each cell in the cluster
    cell_calls = C[z == i, j]
    
    probs = []
    
    for pi_call in range(4):
        prob = 0
        
        for cell_call in cell_calls:
            if T[pi_call, cell_call] == 0:
                prob = -np.inf
                break
            
            prob += np.log(T[pi_call, cell_call])
            
        probs.append(prob)
        
    return np.argmax(probs)

# Sets pi by computing choosing the best log probability variant
# call for each cluster
def set_pi():
    pi = np.zeros((NUM_EMBRYOS, C.shape[1]), dtype=np.int8)
    
    for i in range(NUM_EMBRYOS):
        emb_prob_fun = partial(compute_best_log_prob_pi, i=i)
        
        p = Pool(NUM_THREADS)
        calls = list(tqdm(p.imap(emb_prob_fun, range(C.shape[1])), total=C.shape[1]))
        p.close()
                
        pi[i, :] = calls
            
    return pi

# Computes the log probabilities accross the possible clusters for
# C[i]
def compute_best_log_prob_C(i):
    probs = []
    
    for l in range(NUM_EMBRYOS):
        prob = 0
        
        for j in range(C.shape[1]):
            if T[pi[l, j], C[i, j]] == 0:
                prob = -np.inf
                break
                
            prob += np.log(T[pi[l, j], C[i, j]])
            
        probs.append(prob)
        
    return np.argmax(probs)

def set_z():    
    p = Pool(NUM_THREADS)
    z = np.array(list(tqdm(p.imap(compute_best_log_prob_C, range(C.shape[0])), total=C.shape[0])))
    p.close()
                            
    return z

In [8]:
NUM_THREADS = 90

# Our ground truth data
C = variants_joined.X
#z = np.random.randint(0, NUM_EMBRYOS, size=C.shape[0])
z = np.array(variants_joined.obs.embryo)
T = np.array([[1.        , 0.        , 0.        , 0.        ],
       [0.74778739, 0.25221261, 0.        , 0.        ],
       [0.84580846, 0.03814512, 0.10065051, 0.01539591],
       [0.74778739, 0.        , 0.        , 0.25221261]])

In [9]:
score_prev = -1

while 1:
    print('Setting pi')
    pi = set_pi()
    
    print('Setting z')
    z = set_z()
    
    score = adjusted_rand_score(variants_joined.obs.embryo, z)
    print('ARI is: ', score)
    
    if score-score_prev == 0:
        break
        
    score_prev=score

Setting pi


  0%|          | 0/27462 [00:00<?, ?it/s]

  0%|          | 0/27462 [00:00<?, ?it/s]

  0%|          | 0/27462 [00:00<?, ?it/s]

Setting z


  0%|          | 0/16480 [00:00<?, ?it/s]

ARI is:  1.0
Setting pi


  0%|          | 0/27462 [00:00<?, ?it/s]

  0%|          | 0/27462 [00:00<?, ?it/s]

  0%|          | 0/27462 [00:00<?, ?it/s]

Setting z


  0%|          | 0/16480 [00:00<?, ?it/s]

ARI is:  1.0
