In [2]:
import numpy as np
import pandas as pd
from spatial_OT.OT import *
from spatial_OT.utils import *
from scipy.spatial import distance
import numpy as np
import scanpy as sc
import anndata as ad
import ot
import matplotlib.pyplot as plt
import warnings
import time
warnings.filterwarnings('ignore')

In [3]:
# Slice identifiers in three groups of four, each group listed in slice order.
# Grouping is inferred from the pair lists: consecutive / non-consecutive pairs stay
# inside one group, cross-sample pairs span groups (presumably one group per sample).
_sample_slices = [
    ["151507", "151508", "151509", "151510"],
    ["151669", "151670", "151671", "151672"],
    ["151673", "151674", "151675", "151676"],
]

# Adjacent slices inside a group: (s0, s1), (s1, s2), (s2, s3).
consecutive = [[group[j], group[j + 1]]
               for group in _sample_slices
               for j in range(len(group) - 1)]

# Same-group slices at least two positions apart: (s0, s2), (s0, s3), (s1, s3).
non_consecutive = [[group[a], group[b]]
                   for group in _sample_slices
                   for a in range(len(group))
                   for b in range(a + 2, len(group))]

# Same slice position across groups, for every pair of groups: (A, B), (A, C), (B, C).
cross_sample = [[_sample_slices[x][pos], _sample_slices[y][pos]]
                for pos in range(len(_sample_slices[0]))
                for x in range(len(_sample_slices))
                for y in range(x + 1, len(_sample_slices))]

In [4]:
# Accumulator for alignment metrics: one dict per aligned slice pair.
results = list()

def process_slice(slice_path, data_dir="../LIBD"):
    """Load a spatial transcriptomics slice and attach its ground-truth labels.

    Parameters
    ----------
    slice_path : str
        Slice identifier (e.g. "151507"). Used as the sub-directory name under
        ``data_dir`` and as the prefix of the count-matrix and truth files.
    data_dir : str, default "../LIBD"
        Root directory holding one sub-directory per slice.

    Returns
    -------
    AnnData
        Only the spots with a non-missing label in ``obs["gt"]``. This is an
        independent copy, not a view of the full slice.

    Notes
    -----
    Labels are read from ``{slice_path}_truth.txt``, which is tab-separated with
    no header. Column 0 holds the spot name (matched against ``obs_names``) and
    column 1 holds the label. Spots absent from the truth file are treated as
    unlabeled and dropped, rather than raising ``KeyError``.
    """
    slice_dir = f"{data_dir}/{slice_path}"
    slice_data = sc.read_visium(path=slice_dir, count_file=f"{slice_path}_filtered_feature_bc_matrix.h5")
    slice_data.obs_names_make_unique()
    slice_data.var_names_make_unique()

    truth = pd.read_csv(f"{slice_dir}/{slice_path}_truth.txt", sep='\t', header=None, index_col=0)
    # reindex (not .loc) so spots missing from the truth file become NA instead of raising KeyError.
    slice_data.obs["gt"] = truth[1].reindex(slice_data.obs.index).values

    # Remove NA cells. Copy so callers get a real AnnData instead of a view pinning the full object.
    labeled = slice_data.obs["gt"].notna().values
    return slice_data[labeled].copy()

def _slice_spatial_features(slice_adata, emb, k):
    """Build one slice's spatial graph and return (per-spot entropy array, neighbor-average expression).

    ``emb`` must be the rows of the joint embedding belonging to ``slice_adata``,
    in the same row order. ``k`` is forwarded to ``build_knn_graph_from2d``.
    """
    coords = pd.DataFrame(slice_adata.obsm["spatial"], columns=["x", "y"])
    feat_df = pd.DataFrame(emb, columns=[f"PC{i+1}" for i in range(emb.shape[1])])
    feat_df["x"] = coords["x"].values
    feat_df["y"] = coords["y"].values
    feat_df["cell_type"] = slice_adata.obs["gt"].values

    graph = build_knn_graph_from2d(feat_df, k=k)
    # Added after the graph is built, matching the original column order seen by the builder.
    feat_df["spatial_entropy"] = feat_df.index.map(compute_spatial_entropy(graph))
    avg_expr = compute_average_neighbor_expression(graph, pd.DataFrame(emb))
    return feat_df["spatial_entropy"].values, avg_expr


def compute_spatial_fgw_alignment(slice1, slice2, alpha, epsilon, n_comps=None, k=None, js_k=20):
    """Compute FGW and FGW-SN alignments between two slices and score them.

    Both slices are embedded jointly: normalize_total, then log1p, then PCA. Pairwise
    cost matrices are built and each is scaled to max 1:

    - ``M``: expression distance between slices.
    - ``C1`` / ``C2``: intra-slice spatial distances.
    - ``C3``: absolute difference of per-spot spatial entropy.
    - ``C4``: distance between neighbor-averaged expression.

    ``compute_transport`` is then called without (FGW) and with (FGW-SN) ``C3`` and ``C4``.

    Parameters
    ----------
    slice1, slice2 : AnnData
        Slices with ``obsm["spatial"]`` and ``obs["gt"]``, e.g. from ``process_slice``.
    alpha, epsilon : float
        Forwarded unchanged to ``compute_transport``.
    n_comps : int, optional
        Number of PCA components. ``None`` uses the notebook-level ``n_comps`` global.
    k : int, optional
        Graph size parameter for ``build_knn_graph_from2d``. ``None`` uses the
        notebook-level ``k`` global.
    js_k : int, default 20
        ``k`` forwarded to ``compute_js_divergence_before_after``.

    Returns
    -------
    tuple
        ``(acc_fgw, acc_fgw_sn, js_fgw, js_fgw_sn)``.
    """
    # None keeps the original behaviour of reading the notebook config globals at call time.
    if n_comps is None:
        n_comps = globals()["n_comps"]
    if k is None:
        k = globals()["k"]

    # Joint embedding. Slices may share spot names, so index_unique suffixes them.
    # Rows stay in [slice1, slice2] order, so the positional split below is unaffected.
    joint_adata = sc.concat([slice1, slice2], index_unique="-")
    sc.pp.normalize_total(joint_adata, inplace=True)
    sc.pp.log1p(joint_adata)
    sc.pp.pca(joint_adata, n_comps=n_comps)
    embedding = joint_adata.obsm['X_pca']

    n_first = slice1.shape[0]
    X = embedding[:n_first, :]
    Y = embedding[n_first:, :]

    # Compute spatial graph features for each slice.
    entropy1, slice1_avg_expr = _slice_spatial_features(slice1, X, k)
    entropy2, slice2_avg_expr = _slice_spatial_features(slice2, Y, k)

    # Compute cost matrices.
    M = distance.cdist(X, Y).astype(float)
    C1 = distance.cdist(slice1.obsm["spatial"], slice1.obsm["spatial"]).astype(float)
    C2 = distance.cdist(slice2.obsm["spatial"], slice2.obsm["spatial"]).astype(float)
    C3 = np.abs(entropy1[:, np.newaxis] - entropy2[np.newaxis, :]).astype(float)
    C4 = distance.cdist(slice1_avg_expr.values, slice2_avg_expr.values).astype(float)

    # Scale each matrix to max 1 in place; all-zero matrices are left unchanged.
    for mat in (M, C1, C2, C3, C4):
        peak = mat.max()
        mat /= peak if peak > 0 else 1

    # Compute transport maps with uniform marginals.
    p, q = ot.unif(X.shape[0]), ot.unif(Y.shape[0])
    # Fresh initial plan per solve, in case compute_transport modifies G0 in place (not visible here).
    FGW = compute_transport(np.outer(p, q), epsilon, alpha, C1, C2, p, q, M)
    FGW_SN = compute_transport(np.outer(p, q), epsilon, alpha, C1, C2, p, q, M, C3, C4)

    # Compute accuracy.
    gt1, gt2 = slice1.obs['gt'], slice2.obs['gt']
    acc_fgw = compute_accuracy_max_prob(FGW, gt1, gt2)
    acc_fgw_sn = compute_accuracy_max_prob(FGW_SN, gt1, gt2)

    # Compute JS divergence.
    js_fgw = compute_js_divergence_before_after(
        slice1, compute_transported_adata_argmax(slice1, slice2, FGW), k=js_k, cell_type_key="gt")
    js_fgw_sn = compute_js_divergence_before_after(
        slice1, compute_transported_adata_argmax(slice1, slice2, FGW_SN), k=js_k, cell_type_key="gt")

    return acc_fgw, acc_fgw_sn, js_fgw, js_fgw_sn

In [None]:
# Alignment hyper-parameters. compute_spatial_fgw_alignment reads `n_comps` and `k`
# from the notebook globals, so keep these names.
n_comps = 50   # PCA components for the joint expression embedding
k = 10         # graph size parameter passed to build_knn_graph_from2d
alpha = 0.5    # forwarded to compute_transport
epsilon = 0.1  # forwarded to compute_transport

# Run for all consecutive slices
for pair in consecutive:
    slice1_id, slice2_id = pair
    print(f"Processing alignment: {slice1_id} - {slice2_id}")

    # Load and process slices
    slice1 = process_slice(slice1_id)
    slice2 = process_slice(slice2_id)

    # Compute alignment
    acc_fgw, acc_fgw_sn, js_fgw, js_fgw_sn = compute_spatial_fgw_alignment(slice1, slice2, alpha, epsilon)

    # Drop any row for this pair left by an earlier run of this cell, so re-running it
    # does not duplicate entries. The update is in place, so other references to `results` see it.
    results[:] = [row for row in results
                  if (row["Slice1"], row["Slice2"]) != (slice1_id, slice2_id)]

    # Store results
    results.append({
        "Slice1": slice1_id, "Slice2": slice2_id,
        "FGW_Accuracy": acc_fgw, "FGW-SN_Accuracy": acc_fgw_sn,
        "FGW_JS_Divergence": js_fgw, "FGW-SN_JS_Divergence": js_fgw_sn
    })

    print(f"FGW Accuracy: {acc_fgw:.3f}, FGW-SN Accuracy: {acc_fgw_sn:.3f}")
    print(f"FGW JS Divergence: {js_fgw:.4f}, FGW-SN JS Divergence: {js_fgw_sn:.4f}\n")


Processing alignment: 151507 - 151508
FGW Accuracy: 0.654, FGW-SN Accuracy: 0.735
FGW JS Divergence: 0.3268, FGW-SN JS Divergence: 0.2346

