In [1]:
import scanpy as sc
import squidpy as sq
import numpy as np
import pandas as pd
from anndata import AnnData
import pathlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import skimage
import seaborn as sns
import tangram as tg

from scipy.spatial import distance
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
# List the candidate matrices saved in candidates/ (sizes and modification times).
!ls -l candidates/

total 764804
-rw-rw-r-- 1 apon apon 307208568 May 12 19:09 candidates.npy
-rw-rw-r-- 1 apon apon 307208568 May 13 16:01 candidates_75.npy
-rw-rw-r-- 1 apon apon  56238752 May 14 11:50 candidates_80_tg.npy
-rw-rw-r-- 1 apon apon  56238752 May 15 13:56 candidates_90_tg.npy
-rw-rw-r-- 1 apon apon  56238752 May 15 14:00 candidates_95_tg.npy


In [2]:
# Candidate matrix used for the rest of the notebook.
candidates = np.load(file="candidates/candidates_95_tg.npy")

In [3]:
# Sum over all entries; for a 0/1 matrix this is the number of flagged (cell, voxel) pairs.
np.sum(candidates)

53469

### Tangram (Tg) tutorial data

In [3]:
# Spatial data: load the Visium crop, then keep only spots whose cluster is Cortex_1..Cortex_4.
adata_st = sq.datasets.visium_fluo_adata_crop()
cortex_mask = adata_st.obs["cluster"].isin([f"Cortex_{i}" for i in range(1, 5)])
adata_st = adata_st[cortex_mask].copy()

# Image that accompanies the same crop.
img = sq.datasets.visium_fluo_image_crop()

# Single-cell dataset used as the reference.
adata_sc = sq.datasets.sc_mouse_cortex()

### Create mock data & mock candidates

In [8]:
# Mock single-cell matrix: 100x100 draws from a normal with mean 50 and std 10.
mock_adata_sc = np.random.normal(50, 10, (100, 100))

# Mock spatial matrix: 10x50 draws from the same distribution.
mock_adata_st = np.random.normal(50, 10, (10, 50))

# Mock candidates: 100x10 binary matrix in which column j flags rows 10*j .. 10*j + 9,
# i.e. every column holds exactly ten ones and no row is shared between columns.
row_block = np.arange(100)[:, None] // 10
mock_candidates = (row_block == np.arange(10)[None, :]).astype(int)

### EM post-processing of candidate cells

In [4]:
class EM_postprocess:
    """
    EM-style refinement of the single cells assigned to each spatial voxel.

    Every voxel starts with up to ``n_cells_spot`` randomly drawn candidate cells.
    Each iteration then re-scores the candidates of every voxel with

        alpha * cos(candidate, voxel expression)            (shared genes only)
      + beta  * cos(candidate, mean expression of the cells currently
                    selected in the neighboring voxels)     (all single-cell genes)

    and keeps the best-scoring ones. Within one iteration a cell can be selected
    by at most one voxel; voxels are visited in random order.

    Parameters
    ----------
    candidates : 2-D array, indexed as ``candidates[:, voxel]``
        Entries equal to 1 mark the candidate cells of a voxel. Row indices are
        used directly as row indices into ``adata_sc.X``.
    adata_sc, adata_st : AnnData-like
        Single-cell and spatial data. Must expose ``X``, ``var_names`` and
        ``shape``. ``X`` may be a scipy sparse matrix or a dense array.
        NOTE(review): ``var_names`` are assumed to be unique - confirm.
    n_cells_spot : int
        Maximum number of cells kept per voxel.
    spatial_coords : 2-D array
        One coordinate row per voxel, used for Euclidean neighbor search.
    max_iter : int
        Maximum number of iterations.
    alpha, beta : float
        Weights of the voxel term and the neighbor term of the score.
    radius : float, default 300
        Neighbor search radius, in the units of ``spatial_coords``.
    random_state : int or None, default None
        Seed for the random initialisation and the voxel visiting order.
        ``None`` keeps using numpy's global random state.
    """

    def __init__(self, candidates, adata_sc, adata_st, n_cells_spot, spatial_coords, max_iter, alpha, beta,
                 radius=300, random_state=None):

        self.S = {}  # Contains indices of best fitting cells per voxel (index_voxel: [index_cells])
        self.candidates = candidates
        self.adata_sc = adata_sc
        self.adata_st = adata_st
        self.n_cells_spot = n_cells_spot
        self.spatial_coords = spatial_coords
        self.max_iter = max_iter
        self.alpha = alpha
        self.beta = beta
        self.radius = radius

        # The numpy module itself exposes choice/permutation on the global state, so an
        # unseeded run behaves as before; a seed switches to a private RandomState.
        self._rng = np.random if random_state is None else np.random.RandomState(random_state)

        self.n_voxels = self.adata_st.shape[0]

        # Subset BOTH objects by the same ordered list of gene names. A boolean ``isin``
        # mask would keep each object's own gene order, so column i of the single-cell
        # matrix and column i of the spatial matrix could be different genes and the
        # cosine similarity between them would be meaningless.
        shared_genes = adata_sc.var_names.intersection(adata_st.var_names, sort=False)
        self.adata_sc_shared_genes = adata_sc[:, shared_genes].copy()
        self.adata_st_shared_genes = adata_st[:, shared_genes].copy()

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense ndarray (accepts scipy sparse or dense input)."""
        if hasattr(matrix, "toarray"):
            return matrix.toarray()
        return np.asarray(matrix)

    def _init_candidates(self, v):
        """
        input: binary indicator vector (1 = candidate cell)
        output: np.array (len <= n_cells_spot) with indices of the cells drawn at random to
        represent the spot at init. If there are fewer candidates than n_cells_spot, all of
        them are returned.
        """
        one_indices = np.where(v == 1)[0]

        if len(one_indices) >= self.n_cells_spot:
            return self._rng.choice(one_indices, size=self.n_cells_spot, replace=False)

        return one_indices

    def _get_candidates(self, v):
        """
        Input: binary indicator vector (1 = candidate cell)
        Output: np.array with the indices of the candidates of a voxel.
        """
        return np.where(v == 1)[0]

    def _get_neighbors(self, v, radius=300):
        """
        Given a voxel index and a radius, return indices of voxels within radius.
        Output: np.ndarray: Indices of neighboring voxels.
        """
        ref_coord = self.spatial_coords[v]

        # Compute Euclidean distances from the reference voxel to all others
        dists = np.linalg.norm(self.spatial_coords - ref_coord, axis=1)

        # Get indices of voxels within the radius. ``dists > 0`` drops the voxel itself
        # (and any voxel with exactly the same coordinates).
        return np.where((dists > 0) & (dists <= radius))[0]

    def _cosine_similarity(self, vec1, vec2):
        """
        Computes cosine similarity between two vectors (inputs are flattened first).
        Returns 0.0 when either vector has zero norm instead of dividing by zero.
        """
        vec1 = np.asarray(vec1).ravel()
        vec2 = np.asarray(vec2).ravel()
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denom == 0:
            return 0.0
        return np.dot(vec1, vec2) / denom

    def _neighbor_mean_expression(self, neighbors):
        """
        Mean single-cell expression (all genes) over every cell currently selected in the
        given neighboring voxels, or None when those voxels hold no cells at all.
        """
        neighbor_cells = [cell for n in neighbors for cell in self.S[n]]
        if len(neighbor_cells) == 0:
            return None
        return self._to_dense(self.adata_sc.X[neighbor_cells]).mean(axis=0)

    def _neighbors_similarity(self, vec1, neighbors):
        """
        Vec1 is a vector.
        neighbors is an array containing the index of the neighboring voxels.

        This function averages the expression profile of ALL cells in neighboring voxels
        and then performs cosine similarity between the resulting mean expression vector
        and our candidate vector. Returns 0.0 when there are no neighbor cells.
        """
        mean_neigh_expr = self._neighbor_mean_expression(neighbors)
        if mean_neigh_expr is None:
            return 0.0
        return self._cosine_similarity(vec1, mean_neigh_expr)

    def _compare_S(self, S, S_new):
        """
        Compares if previous selection of cells in all voxels are the same in the new iteration.
        The order of the cells inside a voxel is ignored.
        """
        if S.keys() != S_new.keys():
            return False
        return all(sorted(S[key]) == sorted(S_new[key]) for key in S)

    def run_EM(self):
        """
        Run the refinement. Fills ``self.S`` (voxel -> selected cells), ``self.neighbors``
        (voxel -> neighbor voxels) and ``self.logs`` (one dict per iteration).
        """
        self.logs = []  # List to store iteration logs
        self.neighbors = dict()

        #----- init -----
        for v in range(self.n_voxels):
            self.S[v] = self._init_candidates(self.candidates[:, v])
            self.neighbors[v] = self._get_neighbors(v, radius=self.radius)

        #----- E-M steps -----
        for iteration in range(self.max_iter):
            S_new = {}
            used_cells = set()  # Track cells already assigned to prevent reuse
            voxel_order = self._rng.permutation(self.n_voxels)  # Shuffle voxel update order

            for v in voxel_order:
                candidates = self._get_candidates(self.candidates[:, v])
                available_candidates = [c for c in candidates if c not in used_cells]

                # If no candidates available (all have been used), skip voxel
                if len(available_candidates) == 0:
                    S_new[v] = []
                    continue

                # Both reference profiles are the same for every candidate of this voxel
                # (the neighbor mean reads self.S, i.e. the previous iteration), so they
                # are computed once here rather than once per candidate.
                voxel_expr = self._to_dense(self.adata_st_shared_genes.X[v, :])
                mean_neigh_expr = self._neighbor_mean_expression(self.neighbors[v])

                scores = []
                for c in available_candidates:
                    # E-step: Compute similarity scores
                    vec_c = self._to_dense(self.adata_sc_shared_genes.X[c, :])  # candidate expr pattern for sim_voxel
                    sim_voxel = self._cosine_similarity(vec_c, voxel_expr)

                    if mean_neigh_expr is None:
                        sim_neighbors = 0.0
                    else:
                        vec_cc = self._to_dense(self.adata_sc.X[c, :])  # candidate expr pattern for sim_neighbors
                        sim_neighbors = self._cosine_similarity(vec_cc, mean_neigh_expr)

                    scores.append(self.alpha * sim_voxel + self.beta * sim_neighbors)

                scores = np.array(scores)
                gamma = np.exp(scores - np.max(scores))  # Softmax normalization
                gamma /= gamma.sum()

                # M-step: Select top-n_cells_spot based on gamma (limited by availability).
                # Slicing from ``size - k`` (not ``-k``) so that k == 0 selects nothing.
                k = min(self.n_cells_spot, len(available_candidates))
                order = np.argsort(gamma)
                top_idx = order[order.size - k:]
                selected = [available_candidates[i] for i in top_idx]

                S_new[v] = selected
                used_cells.update(selected)

            # Log progress
            self.logs.append({
                "iteration": iteration,
                "n_voxels_updated": sum(len(v) > 0 for v in S_new.values()),
                "n_cells_used": len(used_cells),
            })

            print(f"Iteration {iteration+1}: {len(used_cells)} cells assigned across {self.n_voxels} voxels.")

            # Check for convergence
            if self._compare_S(self.S, S_new):
                print(f"Converged at iteration {iteration+1}.")
                break

            self.S = S_new

    def assess_EM(self):
        """
        For every voxel in ``self.S`` computes

        * ``result_vox_sim``: cos sim between the mean expression (shared genes) of the
          cells selected for the voxel and the expression of that voxel in the spatial data;
        * ``result_neigh_sim``: cos sim between the mean expression (all single-cell genes)
          of the cells selected for the voxel and the mean expression of the cells selected
          in its neighboring voxels.

        Entries follow the iteration order of ``self.S``; ``result_voxels`` stores the voxel
        index belonging to each entry. Voxels without selected cells get NaN in both lists.
        Requires ``run_EM`` to have been called (it fills ``self.neighbors``).
        """
        self.result_voxels = []
        self.result_vox_sim = []
        self.result_neigh_sim = []

        for v in self.S:
            cells = self.S[v]
            self.result_voxels.append(v)

            # Nothing was assigned to this voxel: the similarities are undefined.
            if len(cells) == 0:
                self.result_vox_sim.append(float("nan"))
                self.result_neigh_sim.append(float("nan"))
                continue

            # Cos sim between our selected cells mean expr and the voxel gene expr
            mean_shared_expr = self._to_dense(self.adata_sc_shared_genes.X[cells]).mean(axis=0)
            voxel_expr = self._to_dense(self.adata_st_shared_genes.X[v, :])
            self.result_vox_sim.append(self._cosine_similarity(mean_shared_expr, voxel_expr))

            # Cos sim between our selected cells mean expr and the cells of neighboring voxels
            mean_full_expr = self._to_dense(self.adata_sc.X[cells]).mean(axis=0)
            self.result_neigh_sim.append(self._neighbors_similarity(mean_full_expr, self.neighbors[v]))

        print("Find results in self.result_vox_sim and self.result_neigh_sim")
        

In [5]:
# Configure the post-processing: up to 31 cells per voxel, a single iteration,
# and equal weight on the voxel term (alpha) and the neighbor term (beta).
post_process = EM_postprocess(
    candidates=candidates,
    adata_sc=adata_sc,
    adata_st=adata_st,
    n_cells_spot=31,
    spatial_coords=adata_st.obsm["spatial"],
    max_iter=1,
    alpha=1.0,
    beta=1.0,
)

In [6]:
# Run the refinement; this fills post_process.S, post_process.neighbors and post_process.logs.
post_process.run_EM()

Iteration 1: 10044 cells assigned across 324 voxels.


In [7]:
# Score the current assignment; results are stored in result_vox_sim and result_neigh_sim.
post_process.assess_EM()

Find results in self.result_vox_sim and self.result_neigh_sim


In [8]:
# Per-voxel cosine similarity between the mean expression of the voxel's selected cells and
# the mean expression of the cells selected in its neighboring voxels.
# Entries follow the iteration order of post_process.S, not the voxel-index order.
post_process.result_neigh_sim

[0.9948187,
 0.9957233,
 0.99351114,
 0.9954435,
 0.99409884,
 0.9966083,
 0.9958935,
 0.9951325,
 0.9963773,
 0.9945666,
 0.9951374,
 0.995827,
 0.9948302,
 0.99586225,
 0.9964565,
 0.99580836,
 0.9953191,
 0.99373466,
 0.99537504,
 0.99578077,
 0.99225044,
 0.9966192,
 0.9953544,
 0.99561065,
 0.99522847,
 0.99576217,
 0.99665827,
 0.99477303,
 0.9967063,
 0.9917535,
 0.99468267,
 0.99489117,
 0.99121743,
 0.99685854,
 0.99592644,
 0.9948886,
 0.99479556,
 0.99580353,
 0.995791,
 0.9964875,
 0.9951145,
 0.99660337,
 0.9931792,
 0.99602234,
 0.9960419,
 0.9958243,
 0.9937087,
 0.99007726,
 0.99607813,
 0.99551153,
 0.9943489,
 0.9958487,
 0.99566287,
 0.9960729,
 0.9966957,
 0.99639875,
 0.9945065,
 0.9950533,
 0.9928419,
 0.9941247,
 0.9962023,
 0.9962427,
 0.9960808,
 0.99073404,
 0.9937926,
 0.99593085,
 0.99379116,
 0.995616,
 0.996926,
 0.995987,
 0.9950623,
 0.99487966,
 0.9956533,
 0.9962867,
 0.99591017,
 0.9966149,
 0.9962183,
 0.9966303,
 0.99453807,
 0.996611,
 0.99492973,


In [9]:
# Per-voxel cosine similarity between the mean expression of the voxel's selected cells and
# the voxel's own expression in the spatial data (shared genes only).
# Entries follow the iteration order of post_process.S, not the voxel-index order.
post_process.result_vox_sim

[0.3786023,
 0.38453674,
 0.39881983,
 0.3963732,
 0.36689952,
 0.36689523,
 0.4005415,
 0.3754332,
 0.38283512,
 0.3856827,
 0.39343023,
 0.35015407,
 0.39047593,
 0.39105162,
 0.4016832,
 0.35448858,
 0.39058974,
 0.39976946,
 0.39225593,
 0.3663721,
 0.39244872,
 0.40435046,
 0.35959187,
 0.39784166,
 0.3942073,
 0.38560835,
 0.3915639,
 0.3763769,
 0.39562875,
 0.3878175,
 0.33515513,
 0.357306,
 0.3895898,
 0.38766146,
 0.38254434,
 0.38498288,
 0.37821043,
 0.38786674,
 0.36814207,
 0.38303545,
 0.37129822,
 0.38524067,
 0.35530126,
 0.38299826,
 0.38098612,
 0.383817,
 0.39047962,
 0.38945848,
 0.3329333,
 0.34766072,
 0.31744134,
 0.38213524,
 0.39598393,
 0.39965865,
 0.4024471,
 0.37784615,
 0.39541057,
 0.37173837,
 0.3893257,
 0.39285153,
 0.36523864,
 0.39544162,
 0.36717328,
 0.37635005,
 0.3809074,
 0.3978137,
 0.38057777,
 0.35669243,
 0.39545566,
 0.39328155,
 0.37405646,
 0.3819462,
 0.3908902,
 0.39348882,
 0.4041733,
 0.38690647,
 0.3906395,
 0.38944757,
 0.3902794,

In [48]:
# Mapping voxel index -> indices of the cells selected for that voxel.
print(post_process.S)

{129: [10049, 10054, 10116, 10117, 9432, 8266, 8246, 8174, 5728, 5858, 5937, 5976, 6068, 6070, 6167, 6550, 6577, 6640, 6819, 6826, 6833, 6842, 7092, 7234, 7269, 7323, 7793, 7909, 8153, 5718, 21673], 212: [860, 344, 2052, 3897, 3947, 3959, 8898, 8733, 8536, 8476, 8352, 8296, 8268, 8187, 8021, 7464, 7306, 7168, 7127, 6936, 6878, 6450, 6429, 6405, 6147, 5942, 5484, 4732, 4119, 21512, 21670], 258: [9807, 9889, 9914, 9957, 9990, 9999, 10077, 9786, 10208, 8716, 8467, 6498, 6659, 6759, 6803, 7216, 7361, 7536, 7640, 8715, 7737, 8169, 8195, 8335, 8349, 8350, 8368, 8369, 8436, 7814, 21506], 117: [7876, 8251, 8286, 8446, 8459, 6245, 6086, 6079, 6067, 3858, 3978, 3984, 4002, 4036, 4083, 4090, 4217, 4484, 3637, 4558, 5030, 5481, 5548, 5766, 5809, 5844, 5872, 5943, 5962, 4613, 21313], 142: [12687, 12701, 10337, 8821, 10251, 10131, 9121, 9193, 9205, 9226, 9230, 9320, 9331, 9336, 9405, 9440, 9487, 9556, 9680, 9689, 9777, 9799, 9803, 9881, 9971, 9977, 9992, 9997, 10072, 10241, 21639], 21: [9494, 9788, 

In [49]:
# Persist the per-iteration log of the run.
with open("em_postprocess_log.pkl", mode="wb") as log_file:
    pickle.dump(post_process.logs, log_file)

In [50]:
# Persist the final voxel -> selected-cells mapping.
with open("em_post_result.pkl", mode="wb") as result_file:
    pickle.dump(post_process.S, result_file)


In [2]:
# Load a previously saved result. pickle executes code on load, so only open files you produced yourself.
with open("results/em_300iter_result.pkl", mode="rb") as result_file:
    results = pickle.load(result_file)

In [4]:
# Load the matching iteration log. pickle executes code on load, so only open files you produced yourself.
with open("results/em_300iter_log.pkl", mode="rb") as log_file:
    log = pickle.load(log_file)

In [3]:
# Object loaded from results/em_300iter_result.pkl; presumably the voxel -> selected-cells
# mapping of a longer run (TODO confirm how that file was produced).
print(results)

{0: [10499, 10534, 10571, 10627, 10688, 11004, 11029, 10470, 21574, 9525, 9394, 6702, 6730, 6772, 6825, 7104, 7120, 7254, 7314, 9398, 7352, 7791, 8039, 8186, 8203, 9052, 9142, 9209, 9227, 7689, 21686], 286: [11873, 11876, 11893, 11894, 8993, 5426, 8539, 8198, 5592, 5681, 5786, 5856, 5863, 5889, 6060, 6072, 6260, 6423, 6480, 6836, 6917, 6937, 7282, 7625, 7657, 7700, 7814, 7973, 8146, 8246, 21682], 316: [7291, 7225, 7029, 6794, 6758, 6392, 6390, 6166, 6146, 6094, 5979, 5970, 5828, 7900, 5734, 5605, 5582, 5432, 5368, 5224, 5159, 4594, 4588, 4550, 4532, 4340, 4338, 4309, 4245, 5657, 21375], 167: [8237, 8245, 8341, 7953, 8599, 8706, 8724, 8859, 9125, 9185, 9213, 8700, 9272, 7925, 7904, 6251, 6400, 6511, 6570, 6572, 6710, 7915, 6746, 6810, 7045, 7423, 7608, 7830, 7851, 6797, 20850], 116: [10281, 10301, 10355, 10356, 10395, 9833, 10412, 8569, 8022, 4149, 4678, 4863, 5145, 5336, 5460, 5500, 5562, 5656, 8408, 5711, 6189, 6218, 6430, 6707, 6756, 6952, 7008, 7587, 7882, 6186, 21555], 82: [5680, 1

In [11]:
# Object loaded from results/em_300iter_log.pkl; one dict per iteration with the keys
# 'iteration', 'n_voxels_updated' and 'n_cells_used'.
print(log)

[{'iteration': 0, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 1, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 2, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 3, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 4, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 5, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 6, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 7, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 8, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 9, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 10, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 11, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 12, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 13, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'iteration': 14, 'n_voxels_updated': 324, 'n_cells_used': 10044}, {'it