In [8]:
import tempfile
import ot
import anndata
import matplotlib.pyplot as plt
import numpy as np
from spatial_OT.OT import *
from spatial_OT.utils import *
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scipy.stats import spearmanr
from scvi.data import cortex, smfish
from scvi.external import GIMVI
import pandas as pd
from scipy.spatial import distance

In [9]:
import os

# Directory that holds (or will receive) the downloaded osmFISH and Cortex files.
# Set the OT_DATA_DIR environment variable to point elsewhere instead of editing
# this absolute path; the original location is kept as the fallback.
save_dir = os.environ.get(
    "OT_DATA_DIR",
    "/home/fceccarelli/home3/OT_simulation/code/publication_code/",
)
spatial_data = smfish(save_path=save_dir)
seq_data = cortex(save_path=save_dir)

[34mINFO    [0m File [35m/home/fceccarelli/home3/OT_simulation/code/publication_code/[0m[95mosmFISH_SScortex_mouse_all_cell.loom[0m     
         already downloaded                                                                                        
[34mINFO    [0m Loading smFISH dataset                                                                                    
[34mINFO    [0m File [35m/home/fceccarelli/home3/OT_simulation/code/publication_code/[0m[95mexpression.bin[0m already downloaded        
[34mINFO    [0m Loading Cortex data from [35m/home/fceccarelli/home3/OT_simulation/code/publication_code/[0m[95mexpression.bin[0m       
[34mINFO    [0m Finished loading Cortex data                                                                              




In [10]:
# Restrict the scRNA-seq reference to the genes of the osmFISH panel so both
# modalities share one feature space, in the same column order.
shared_genes = spatial_data.var_names
seq_data = seq_data[:, shared_genes].copy()

for adata_summary in (seq_data, spatial_data):
    print(adata_summary)

AnnData object with n_obs × n_vars = 3005 × 33
    obs: 'labels', 'precise_labels', 'cell_type'
AnnData object with n_obs × n_vars = 4530 × 33
    obs: 'x_coord', 'y_coord', 'labels', 'str_labels', 'batch'
    uns: 'cell_types'


In [11]:
# --- Train/test gene split ----------------------------------------------------
# A random `train_size` fraction of the shared genes drives the alignment; the
# rest are held out and imputed later. The global NumPy seed fixes the split.
np.random.seed(42)
train_size = 0.8
n_genes = seq_data.n_vars
n_train_genes = int(n_genes * train_size)

rand_train_gene_idx = np.random.choice(n_genes, n_train_genes, replace=False)
rand_test_gene_idx = np.setdiff1d(np.arange(n_genes), rand_train_gene_idx)

gene_names = seq_data.var_names
rand_train_genes, rand_test_genes = (
    gene_names[rand_train_gene_idx],
    gene_names[rand_test_gene_idx],
)

# Both modalities see only the training genes during alignment.
seq_data_partial = seq_data[:, rand_train_genes].copy()
spatial_data_partial = spatial_data[:, rand_train_genes].copy()

# Drop cells whose counts over the training genes are all zero.
sc.pp.filter_cells(seq_data_partial, min_counts=1)
sc.pp.filter_cells(spatial_data_partial, min_counts=1)

# Subset the full-gene objects to the surviving cells, in the same order, so
# row i of each full object matches row i of its partial counterpart.
seq_data = seq_data[seq_data_partial.obs_names].copy()
spatial_data = spatial_data[spatial_data_partial.obs_names].copy()

In [12]:
# Library-size normalise (to 1e4 counts per cell) and log1p-transform the
# training-gene objects used for the alignment. sc.pp.log1p records an entry
# under .uns["log1p"], so the guard keeps this cell safe to re-run without
# transforming the same data twice.
for adata in (seq_data_partial, spatial_data_partial):
    if "log1p" in adata.uns:
        continue
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

In [13]:
def to_dense_array(matrix):
    """Return ``matrix`` as a dense NumPy array, densifying scipy.sparse input."""
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def scale_by_max(matrix):
    """Return a float copy of ``matrix`` divided by its maximum.

    If the maximum is not positive (all zeros, or NaN), the matrix is returned
    unscaled, which avoids a division by zero.
    """
    matrix = np.asarray(matrix, dtype=float)
    peak = matrix.max()
    return matrix / peak if peak > 0 else matrix


def compute_spatial_fgw_alignment(spatial_slice, seq_slice, alpha, epsilon, k=10):
    """Compute the TOAST transport plan between a spatial slice and a scRNA-seq slice.

    Parameters
    ----------
    spatial_slice : AnnData
        Spatial data. Reads ``.X``, ``obs["x_coord"]``, ``obs["y_coord"]`` and
        ``obs["str_labels"]``.
    seq_slice : AnnData
        Single-cell reference. Reads ``.X`` and ``obs["cell_type"]``. It must share
        the gene columns of ``spatial_slice``, because the cost ``M`` is a distance
        between rows of the two ``.X`` matrices.
    alpha, epsilon : float
        Forwarded unchanged to ``compute_transport``.
    k : int, default 10
        Neighbourhood size for both kNN graphs. Earlier versions read a
        notebook-global ``k`` instead.

    Returns
    -------
    The result of ``compute_transport``. The notebook uses it as an
    (n_spatial_cells, n_seq_cells) matrix (``TOAST @ seq.X``).
    """
    X = to_dense_array(spatial_slice.X)
    Y = to_dense_array(seq_slice.X)
    coords = spatial_slice.obs[["x_coord", "y_coord"]].to_numpy()

    # Spatial slice: kNN graph built from the physical coordinates, plus
    # per-cell entropy and neighbour-averaged expression on that graph.
    X_df = pd.DataFrame(X, columns=[f"PC{i+1}" for i in range(X.shape[1])])
    X_df["x"] = spatial_slice.obs["x_coord"].to_numpy()
    X_df["y"] = spatial_slice.obs["y_coord"].to_numpy()
    X_df["cell_type"] = spatial_slice.obs["str_labels"].values

    G1 = build_knn_graph_from2d(X_df, k=k)
    X_df["spatial_entropy"] = X_df.index.map(compute_spatial_entropy(G1))
    slice1_avg_expr = compute_average_neighbor_expression(G1, pd.DataFrame(X))

    # scRNA-seq slice: kNN graph built in expression space, with cell-type labels.
    Y_df = pd.DataFrame(Y, columns=[f"PC{i+1}" for i in range(Y.shape[1])])
    G2 = build_knn_graph_expression(Y_df, seq_slice.obs["cell_type"], k=k)
    Y_df["spatial_entropy"] = Y_df.index.map(compute_spatial_entropy(G2))
    slice2_avg_expr = compute_average_neighbor_expression(G2, pd.DataFrame(Y))

    # Cost matrices:
    #   M  - cross-modality expression distance (spatial x seq)
    #   C1 - pairwise physical distance between spatial cells
    #   C2 - pairwise expression distance between seq cells
    #   C3 - absolute difference of the graph entropies (spatial x seq)
    #   C4 - distance between neighbour-averaged expression (spatial x seq)
    entropy_x = X_df["spatial_entropy"].to_numpy()
    entropy_y = Y_df["spatial_entropy"].to_numpy()
    M, C1, C2, C3, C4 = (
        scale_by_max(cost)
        for cost in (
            distance.cdist(X, Y),
            distance.cdist(coords, coords),
            distance.cdist(Y, Y),
            np.abs(entropy_x[:, np.newaxis] - entropy_y[np.newaxis, :]),
            distance.cdist(slice1_avg_expr.values, slice2_avg_expr.values),
        )
    )

    # Uniform marginals over cells; the independent coupling is the initial plan.
    p, q = ot.unif(X.shape[0]), ot.unif(Y.shape[0])
    G0 = np.outer(p, q)

    return compute_transport(G0, epsilon, alpha, C1, C2, p, q, M, C3, C4)

In [14]:
# Neighbourhood size for the kNN graphs built during the alignment.
k = 10

# Rows of TOAST index spatial cells, columns index scRNA-seq cells.
TOAST = compute_spatial_fgw_alignment(
    spatial_data_partial,
    seq_data_partial,
    alpha=0.5,
    epsilon=0.5,
)

In [15]:
# Impute every gene (held-out ones included) for each spatial cell as a
# transport-weighted average of scRNA-seq profiles (barycentric projection).
# NOTE(review): seq_data.X is not passed through normalize_total/log1p in this
# notebook (only the *_partial objects are); confirm raw values are intended here.
X_sc_full = seq_data.X
X_sc_full = X_sc_full.toarray() if hasattr(X_sc_full, "toarray") else np.asarray(X_sc_full)

transport_plan = np.asarray(TOAST)
# A bare `plan @ X` is a mass-weighted *sum*, whose scale depends on how much
# mass each spatial cell sends. Dividing by the row mass turns it into a weighted
# average. Rows that carry no mass stay at zero instead of becoming NaN.
row_mass = transport_plan.sum(axis=1, keepdims=True)
X_imputed = np.divide(
    transport_plan @ X_sc_full,
    row_mass,
    out=np.zeros((transport_plan.shape[0], X_sc_full.shape[1])),
    where=row_mass > 0,
)

spatial_data.layers["imputed_ot"] = X_imputed

In [16]:
# Per-gene Spearman correlation between measured and OT-imputed values for the
# held-out genes. These genes were excluded from the *_partial objects used
# for the alignment.
test_idx = [spatial_data.var_names.get_loc(g) for g in rand_test_genes]

true = spatial_data.X[:, test_idx]
pred = spatial_data.layers["imputed_ot"][:, test_idx]

# convert sparse → dense if needed
true = true.toarray() if hasattr(true, "toarray") else np.asarray(true)
pred = pred.toarray() if hasattr(pred, "toarray") else np.asarray(pred)

gene_corrs = []
skipped_genes = []
# Column i of true/pred corresponds to rand_test_genes[i] (same order as test_idx).
for i, gene in enumerate(rand_test_genes):
    # Spearman is undefined for a constant vector. Record the gene instead of
    # silently shrinking the set the mean is taken over.
    if np.std(true[:, i]) == 0 or np.std(pred[:, i]) == 0:
        skipped_genes.append(gene)
        continue
    corr, _ = spearmanr(true[:, i], pred[:, i])
    gene_corrs.append(corr)

print(f"Scored {len(gene_corrs)}/{len(test_idx)} held-out genes")
if skipped_genes:
    print(f"Skipped (constant true or imputed values): {list(skipped_genes)}")

print ("Mean Spearman Corr")
np.mean(gene_corrs) if gene_corrs else float("nan")

Mean Spearman Corr


0.21380246202095263