In [None]:
import scanpy as sc
import numpy as np
import wcd_vae
from wcd_vae.scCRAFT.model import train_integration_model, obtain_embeddings
from wcd_vae.scCRAFT.utils import multi_resolution_cluster
import scvi
import scib 
import harmonypy as hm
import pandas as pd
import scanorama
import time
import bbknn
import scDML
import imap
from scib.utils import *
import torch

In [None]:
def plot_umap_by_technology(adata, batch_key='tech', color_key='celltype', ncols=3, figsize_per_panel=(5, 5)):
    """
    Plot per-batch UMAP panels with shared axis limits and one shared cell-type palette.

    Two figures are shown: a grid with one panel per value of ``batch_key``, followed by
    a side-by-side overview colored by ``batch_key`` (left) and ``color_key`` (right).

    Side effects: the UMAP is recomputed with ``sc.tl.umap`` (overwriting
    ``adata.obsm['X_umap']``) and ``adata.obs[color_key]`` is converted to a
    categorical column when it is not one already.

    Parameters:
    -----------
    adata : AnnData
        Annotated data object. NOTE(review): ``sc.tl.umap`` is called here, so a
        neighbor graph presumably has to exist on ``adata`` already - confirm.
    batch_key : str, default 'tech'
        Key in adata.obs containing batch/technology information
    color_key : str, default 'celltype'
        Key in adata.obs for coloring points
    ncols : int, default 3
        Maximum number of columns in subplot grid
    figsize_per_panel : tuple, default (5, 5)
        Size of each subplot panel

    Returns:
    --------
    None (displays plots)
    """
    import matplotlib.pyplot as plt
    import scanpy as sc
    import numpy as np
    import pandas as pd

    sc.tl.umap(adata, min_dist=0.5)

    # Ensure cell types are categorical. isinstance on the dtype replaces the
    # deprecated pd.api.types.is_categorical_dtype helper.
    if not isinstance(adata.obs[color_key].dtype, pd.CategoricalDtype):
        adata.obs[color_key] = adata.obs[color_key].astype('category')

    technologies = adata.obs[batch_key].unique()
    cell_types = adata.obs[color_key].cat.categories

    # One fixed color per cell type, reused in every panel of both figures.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists in
    # recent Matplotlib releases.
    cmap = plt.get_cmap('tab20', len(cell_types))
    color_dict = {cell_type: cmap(i) for i, cell_type in enumerate(cell_types)}

    coords = adata.obsm['X_umap']
    batch_values = np.asarray(adata.obs[batch_key])
    celltype_values = np.asarray(adata.obs[color_key])

    # Global limits so every panel shares the same coordinate frame.
    x_min, x_max = coords[:, 0].min() - 0.5, coords[:, 0].max() + 0.5
    y_min, y_max = coords[:, 1].min() - 0.5, coords[:, 1].max() + 0.5

    def _scatter_by_celltype(ax, cell_mask):
        """Scatter the selected cells on ``ax``, one call per cell type, then fix the limits."""
        for cell_type in cell_types:
            selected = cell_mask & (celltype_values == cell_type)
            if selected.any():  # Only plot if there are cells of this type
                ax.scatter(
                    coords[selected, 0],
                    coords[selected, 1],
                    color=color_dict[cell_type],
                    s=1, alpha=0.7, label=cell_type
                )
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

    def _add_celltype_legend(fig):
        """Attach the shared cell-type legend to the right of ``fig`` and reserve room for it."""
        handles = [
            plt.Line2D([0], [0], marker='o', color=color_dict[cell_type],
                       label=cell_type, markersize=5, linestyle='None')
            for cell_type in cell_types
        ]
        fig.legend(handles, list(cell_types), loc='center right', bbox_to_anchor=(1.15, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to make room for legend

    # ---- Figure 1: one panel per technology ----
    n_techs = len(technologies)
    ncols = min(ncols, n_techs)
    nrows = (n_techs + ncols - 1) // ncols

    # squeeze=False always yields a 2-D array, so no special-casing of 1x1 / 1xN grids.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows),
        squeeze=False,
    )
    axes = axes.ravel()

    for ax, tech in zip(axes, technologies):
        _scatter_by_celltype(ax, batch_values == tech)
        ax.set_xlabel('UMAP1')
        ax.set_ylabel('UMAP2')
        ax.set_title(f'{batch_key.capitalize()}: {tech}')
        ax.set_aspect('equal')

    # Hide any unused subplots
    for ax in axes[n_techs:]:
        ax.set_visible(False)

    _add_celltype_legend(fig)
    plt.show()

    # ---- Figure 2: all technologies together ----
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)

    # Left: colored by technology/batch using scanpy's own palette
    sc.pl.umap(adata, color=batch_key, ax=ax1, frameon=False, show=False)
    ax1.set_xlim(x_min, x_max)
    ax1.set_ylim(y_min, y_max)
    ax1.set_title(f'Colored by {batch_key.capitalize()}')

    # Right: colored by cell type with the same palette as the per-batch panels
    _scatter_by_celltype(ax2, np.ones(coords.shape[0], dtype=bool))
    ax2.set_title(f'Colored by {color_key.capitalize()}')

    _add_celltype_legend(fig)
    plt.show()

In [None]:
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from scipy.stats import entropy
import scanpy as sc
from tqdm import tqdm

def compute_lisi(X, metadata, label_colname, perplexity=30):
    """
    Compute the Local Inverse Simpson Index (LISI) for every cell.

    For each cell, its nearest neighbors are weighted with a Gaussian kernel whose
    bandwidth is tuned by ``find_sigma`` to the requested perplexity. The score is
    the inverse Simpson index of the weighted label distribution, so it ranges from
    1 (all neighbors share one label) up to the number of distinct labels.

    Parameters:
    -----------
    X : array-like, shape (n_samples, n_features)
        The embedded data matrix
    metadata : pandas.DataFrame
        Metadata containing batch/label information
    label_colname : str
        Column name in metadata containing the batch labels
    perplexity : int, default=30
        Perplexity parameter for Gaussian kernel (similar to t-SNE)

    Returns:
    --------
    lisi_scores : numpy.ndarray, shape (n_samples,)
        LISI score for each cell

    Raises:
    -------
    ValueError
        If fewer than 3 cells are supplied (every cell needs at least one
        neighbor besides itself for the weights to be defined).
    """
    n_cells = X.shape[0]
    if n_cells < 3:
        raise ValueError(f"compute_lisi needs at least 3 cells, got {n_cells}")

    # Integer-encode the labels in one pass (codes index into unique_batches).
    unique_batches, batch_indices = np.unique(
        np.asarray(metadata[label_colname]), return_inverse=True
    )
    n_batches = len(unique_batches)

    # k includes the query point itself (it is dropped below), so each cell is
    # scored on k - 1 true neighbors.
    k = min(90, n_cells - 1)  # Use 90 neighbors or n_cells-1 if smaller
    print(f"Computing {k} nearest neighbors for {n_cells} cells...")
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X)
    distances, indices = nbrs.kneighbors(X)

    # Column 0 is the query point itself.
    neighbor_distances = distances[:, 1:]
    neighbor_batches = batch_indices[indices[:, 1:]]

    lisi_scores = np.zeros(n_cells)

    print(f"Computing LISI scores for {label_colname}...")
    for i in tqdm(range(n_cells), desc="Computing LISI"):
        cell_distances = neighbor_distances[i]

        # Bandwidth that gives the desired perplexity for this neighborhood
        sigma = find_sigma(cell_distances, perplexity)

        # Shift by the smallest squared distance before exponentiating: the
        # normalized weights are unchanged, but exp() can no longer underflow
        # to all zeros (which used to give 0/0 = NaN).
        sq = cell_distances ** 2
        weights = np.exp(-(sq - sq.min()) / (2 * sigma**2))
        weights = weights / np.sum(weights)

        # Weighted label histogram of the neighborhood; sums to 1.
        batch_probs = np.bincount(neighbor_batches[i], weights=weights, minlength=n_batches)

        # Inverse Simpson index. The probabilities sum to 1, so the denominator
        # is at least 1/n_batches and no epsilon is required.
        lisi_scores[i] = 1.0 / np.sum(batch_probs**2)

    return lisi_scores

def find_sigma(distances, target_perplexity, tol=1e-5, max_iter=50):
    """
    Find the Gaussian kernel bandwidth (sigma) that achieves target perplexity.

    Bisection over sigma in (1e-20, 1000), similar to the t-SNE bandwidth search.
    The perplexity is ``2**H`` where ``H`` is the base-2 entropy of the normalized
    kernel weights ``exp(-d**2 / (2 * sigma**2))``.

    Parameters:
    -----------
    distances : array-like, shape (n_neighbors,)
        Distances from one cell to its neighbors
    target_perplexity : float
        Desired perplexity of the weight distribution
    tol : float, default=1e-5
        Stop once the perplexity is within this distance of the target
    max_iter : int, default=50
        Maximum number of bisection steps

    Returns:
    --------
    float
        The last sigma evaluated (the interval midpoint if ``max_iter`` is 0)

    Raises:
    -------
    ValueError
        If ``distances`` is empty.
    """
    sq_dists = np.asarray(distances, dtype=float) ** 2
    if sq_dists.size == 0:
        raise ValueError("distances must contain at least one neighbor")

    # Normalized weights are invariant to subtracting a constant from the squared
    # distances. Shifting by the minimum keeps the largest weight at exp(0) = 1, so
    # the sum can never underflow to 0 and turn the perplexity into NaN (which
    # previously steered the search to a wrong bandwidth).
    sq_dists = sq_dists - sq_dists.min()

    def perplexity_fn(sigma):
        if sigma <= 0:
            return 0
        weights = np.exp(-sq_dists / (2 * sigma**2))
        weights = weights / np.sum(weights)
        # Zero-weight neighbors contribute nothing to the entropy; skip them to avoid log(0)
        nonzero = weights[weights > 0]
        H = -np.sum(nonzero * np.log2(nonzero))
        return 2**H

    # Binary search for sigma
    sigma_min, sigma_max = 1e-20, 1000.0
    sigma = (sigma_min + sigma_max) / 2.0  # defined even when max_iter == 0

    for _ in range(max_iter):
        sigma = (sigma_min + sigma_max) / 2.0
        perp = perplexity_fn(sigma)

        if abs(perp - target_perplexity) < tol:
            break

        # Perplexity grows with sigma, so shrink the side that overshoots.
        if perp > target_perplexity:
            sigma_max = sigma
        else:
            sigma_min = sigma

    return sigma

def ilisi_graph(adata, batch_key, type="embed", use_rep="X_pca", perplexity=30):
    """
    Compute integration Local Inverse Simpson Index (iLISI) for an AnnData object.

    Per-cell LISI scores from ``compute_lisi`` lie between 1 and the number of
    batches; they are rescaled with ``(lisi - 1) / (n_batches - 1)`` so that 0 means
    every neighborhood contains a single batch and 1 means batches are evenly mixed.

    Parameters:
    -----------
    adata : AnnData
        Annotated data object
    batch_key : str
        Key in adata.obs containing batch information
    type : str, default="embed"
        "embed" reads ``adata.obsm[use_rep]``; any other value uses ``adata.X``
    use_rep : str, default="X_pca"
        Key in adata.obsm for the embedding to use
    perplexity : int, default=30
        Perplexity parameter for neighborhood definition

    Returns:
    --------
    float
        Mean rescaled iLISI score across all cells (0-1 range, higher = better mixing)

    Raises:
    -------
    ValueError
        If the embedding or batch key is missing, or if fewer than two batches
        are present (the rescaling would divide by zero).
    """
    if type == "embed":
        print("Using embed")
        if use_rep not in adata.obsm:
            raise ValueError(f"Embedding {use_rep} not found in adata.obsm")
        X = adata.obsm[use_rep]
    else:
        X = adata.X

    if batch_key not in adata.obs:
        raise ValueError(f"Batch key {batch_key} not found in adata.obs")

    # Number of unique batches, used to rescale the score
    n_batches = len(adata.obs[batch_key].unique())
    if n_batches < 2:
        # Checked before the expensive LISI pass: with one batch the rescaling
        # below is 0/0 and the metric is undefined.
        raise ValueError(
            f"iLISI needs at least 2 batches in '{batch_key}', found {n_batches}"
        )

    # Compute LISI scores
    print("Computing LISI")
    lisi_scores = compute_lisi(X, adata.obs, batch_key, perplexity)

    # Rescale from [1, n_batches] to [0, 1] (0 = no mixing, 1 = perfect mixing)
    normalized_scores = (lisi_scores - 1) / (n_batches - 1)

    # Return mean normalized iLISI score
    return np.mean(normalized_scores)

def clisi_graph(adata, label_key, type="embed", use_rep="X_pca", perplexity=30):
    """
    Compute cell-type Local Inverse Simpson Index (cLISI) for an AnnData object.

    Per-cell LISI scores from ``compute_lisi`` lie between 1 and the number of cell
    types; they are rescaled with ``(lisi - 1) / (n_celltypes - 1)`` so that 0 means
    every neighborhood contains a single cell type and 1 means cell types are
    evenly mixed. With this scaling lower values indicate better label separation.

    Parameters:
    -----------
    adata : AnnData
        Annotated data object
    label_key : str
        Key in adata.obs containing cell type information
    type : str, default="embed"
        "embed" reads ``adata.obsm[use_rep]``; any other value uses ``adata.X``
    use_rep : str, default="X_pca"
        Key in adata.obsm for the embedding to use
    perplexity : int, default=30
        Perplexity parameter for neighborhood definition

    Returns:
    --------
    float
        Mean rescaled cLISI score across all cells (0-1 range, lower = purer neighborhoods)

    Raises:
    -------
    ValueError
        If the embedding or label key is missing, or if fewer than two cell
        types are present (the rescaling would divide by zero).
    """
    if type == "embed":
        print("Using embed")
        if use_rep not in adata.obsm:
            raise ValueError(f"Embedding {use_rep} not found in adata.obsm")
        X = adata.obsm[use_rep]
    else:
        X = adata.X

    if label_key not in adata.obs:
        raise ValueError(f"Label key {label_key} not found in adata.obs")

    # Number of unique cell types, used to rescale the score
    n_celltypes = len(adata.obs[label_key].unique())
    if n_celltypes < 2:
        # Checked before the expensive LISI pass: with one label the rescaling
        # below is 0/0 and the metric is undefined.
        raise ValueError(
            f"cLISI needs at least 2 labels in '{label_key}', found {n_celltypes}"
        )

    print("Computing LISI")
    # Compute LISI scores
    lisi_scores = compute_lisi(X, adata.obs, label_key, perplexity)

    # Rescale from [1, n_celltypes] to [0, 1] (0 = pure neighborhoods, 1 = fully mixed)
    normalized_scores = (lisi_scores - 1) / (n_celltypes - 1)

    # Return mean normalized cLISI score
    return np.mean(normalized_scores)

In [None]:
# Seed the global random number generators used in this notebook.
# torch alone is not enough: steps that draw from NumPy's global RNG would
# otherwise differ between runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the pancreas benchmark and run the preprocessing used by every method below.
# NOTE(review): absolute, machine-specific path - consider a configurable DATA_DIR.
adata = sc.read_h5ad("/workspaces/data/human_pancreas_norm_complexBatch.h5ad")

# Keep the unfiltered matrix around before any filtering/normalization.
# NOTE(review): the file name says "norm"; confirm that X really holds raw counts
# before treating this layer as counts downstream.
adata.raw = adata
adata.layers["counts"] = adata.X.copy()

# Basic QC
sc.pp.filter_cells(adata, min_genes=300)
sc.pp.filter_genes(adata, min_cells=5)

# Library-size normalization to 10,000 per cell, then log1p.
# normalize_total replaces the deprecated normalize_per_cell; key_added keeps the
# per-cell totals in obs['n_counts'] as the old function did.
sc.pp.normalize_total(adata, target_sum=1e4, key_added='n_counts')
sc.pp.log1p(adata)

# Batch-aware selection of 2000 highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=2000, batch_key='tech')

# .copy() materializes the subset; without it adata is a view that the next
# call would have to implicitly convert when it writes to it.
adata = adata[:, adata.var['highly_variable']].copy()
multi_resolution_cluster(adata, resolution1 = 1, method = 'Leiden')

In [None]:
# UMAP of the preprocessed data before any integration model is trained.
plot_umap_by_technology(
    adata,
    batch_key='tech',
    color_key='celltype',
)

In [None]:
# scCRAFT run: critic=True, d_coef=0.01, disc_iter=10
# Fall back to CPU when no GPU is visible instead of failing on a hard-coded "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

VAE = train_integration_model(adata, batch_key = 'tech', z_dim=256, d_coef = 0.01, epochs=100, critic=True, disc_iter=10)
obtain_embeddings(adata, VAE.to(device))
sc.pp.neighbors(adata, use_rep="X_scCRAFT")
plot_umap_by_technology(adata, batch_key='tech', color_key='celltype')

# Silhouette scores: cell-type separation, then batch mixing within cell types
print(scib.me.silhouette(adata, label_key="celltype", embed="X_scCRAFT", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_scCRAFT", scale=True))

ilisi_score = ilisi_graph(adata, batch_key="tech", type="embed", use_rep="X_scCRAFT")
print(f"iLISI score (1 is best): {ilisi_score:.4f}")

clisi_score = clisi_graph(adata, label_key="celltype", type="embed", use_rep="X_scCRAFT")
print(f"cLISI score (0 is best): {clisi_score:.4f}")

In [None]:
# scCRAFT run: critic=True, d_coef=0.4, disc_iter=10
# Fall back to CPU when no GPU is visible instead of failing on a hard-coded "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

VAE = train_integration_model(adata, batch_key = 'tech', z_dim=256, d_coef = 0.4, epochs=100, critic=True, disc_iter=10)
obtain_embeddings(adata, VAE.to(device))
sc.pp.neighbors(adata, use_rep="X_scCRAFT")
plot_umap_by_technology(adata, batch_key='tech', color_key='celltype')

# Silhouette scores: cell-type separation, then batch mixing within cell types
print(scib.me.silhouette(adata, label_key="celltype", embed="X_scCRAFT", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_scCRAFT", scale=True))

ilisi_score = ilisi_graph(adata, batch_key="tech", type="embed", use_rep="X_scCRAFT")
print(f"iLISI score (1 is best): {ilisi_score:.4f}")

clisi_score = clisi_graph(adata, label_key="celltype", type="embed", use_rep="X_scCRAFT")
print(f"cLISI score (0 is best): {clisi_score:.4f}")

In [None]:
# scCRAFT run: critic=False, d_coef=0.1, disc_iter=1
# Fall back to CPU when no GPU is visible instead of failing on a hard-coded "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

VAE = train_integration_model(adata, batch_key = 'tech', z_dim=256, d_coef = 0.1, epochs=100, critic=False, disc_iter=1)
obtain_embeddings(adata, VAE.to(device))
sc.pp.neighbors(adata, use_rep="X_scCRAFT")
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# Silhouette scores: cell-type separation, then batch mixing within cell types
print(scib.me.silhouette(adata, label_key="celltype", embed="X_scCRAFT", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_scCRAFT", scale=True))

ilisi_score = ilisi_graph(adata, batch_key="tech", type="embed", use_rep="X_scCRAFT")
print(f"iLISI score (1 is best): {ilisi_score:.4f}")

clisi_score = clisi_graph(adata, label_key="celltype", type="embed", use_rep="X_scCRAFT")
print(f"cLISI score (0 is best): {clisi_score:.4f}")

In [None]:
# scVI
adata = adata.copy()

# scVI is fitted with a negative-binomial likelihood on the "counts" layer.
# Reuse the layer stored at load time; at this point adata.X is normalized and
# log-transformed, so copying it over "counts" would feed the wrong values.
# NOTE(review): confirm the stored layer really contains raw counts.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()

scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="tech")
vae = scvi.model.SCVI(adata, n_layers=2, n_latent=50, gene_likelihood="nb")
vae.train()
adata.obsm["X_scVI"] = vae.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# Score the scVI latent space (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="X_scVI", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_scVI", scale=True))

In [None]:
# Harmony
sc.tl.pca(adata, n_comps=50)
data_mat = adata.obsm['X_pca']
meta_data = adata.obs

# Covariate(s) Harmony should correct for
vars_use = ['tech']

# Run Harmony
start_time = time.time()
ho = hm.run_harmony(data_mat, meta_data, vars_use)
end_time = time.time()
training_time = end_time - start_time
print(f"Training completed in {training_time:.2f} seconds")

# Store the corrected PCs as (n_cells, n_components). Transpose only when cells
# are on the second axis, so either orientation of Z_corr is handled.
Z_corr = np.asarray(ho.Z_corr)
if Z_corr.shape[1] == adata.n_obs:
    Z_corr = Z_corr.T
adata.obsm['X_harmony'] = Z_corr

sc.pp.neighbors(adata, use_rep="X_harmony")
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# Score the Harmony embedding (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="X_harmony", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_harmony", scale=True))

In [None]:
#Scanorama
# Save original order of cells so the corrected rows can be realigned afterwards
original_order = adata.obs_names.copy()

# Start timer
start_time = time.time()

# Split by batch, correct, and stitch the pieces back together.
# The batch column in this dataset is 'tech' (as used by every other cell);
# sc.AnnData avoids relying on an `anndata` name that is never imported here.
split, categories = split_batches(adata.copy(), 'tech', return_categories=True)
corrected = scanorama.correct_scanpy(split, return_dimred=True)
corrected = sc.AnnData.concatenate(
    *corrected, batch_key='tech', batch_categories=categories, index_unique=None
)

# Reorder corrected data to match original order
corrected = corrected[original_order]

# End timer
end_time = time.time()
training_time = end_time - start_time
print(f"Training completed in {training_time:.2f} seconds")

# Only the low-dimensional Scanorama embedding is carried over to adata
adata.obsm['X_scanorama'] = corrected.obsm['X_scanorama']

# Neighbors, UMAP, plotting
sc.pp.neighbors(adata, n_pcs=30, use_rep="X_scanorama")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# Score the Scanorama embedding (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="X_scanorama", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_scanorama", scale=True))

In [None]:
# BBKNN
# NOTE(review): bbknn presumably builds its graph from adata.obsm['X_pca'];
# compute it here if an earlier cell has not done so - confirm against bbknn docs.
if 'X_pca' not in adata.obsm:
    sc.tl.pca(adata, n_comps=50)

start_time = time.time()
bbknn.bbknn(adata, batch_key='tech')  # batch column in this dataset is 'tech'
end_time = time.time()
training_time = end_time - start_time
print(f"Training completed in {training_time:.2f} seconds")
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# BBKNN only corrects the neighbor graph, so its UMAP is stored as the
# embedding to score. Copy it so later UMAP runs cannot alias this array.
adata.obsm['X_bbknn'] = adata.obsm['X_umap'].copy()

# Score the BBKNN embedding (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="X_bbknn", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="X_bbknn", scale=True))

In [None]:
#iMAP

# Densify X when it is not already a NumPy array (e.g. a scipy sparse matrix).
# isinstance also accepts ndarray subclasses, which have no .toarray().
if not isinstance(adata.X, np.ndarray):
    adata.X = adata.X.toarray()

start_time = time.time()
### Stage I
# batch column in this dataset is 'tech'
EC, ec_data = imap.stage1.iMAP_fast(adata, key='tech', n_epochs=50)
### Stage II
output_results = imap.stage2.integrate_data(adata, ec_data, key='tech', n_epochs=40)
end_time = time.time()
print('total time talken', end_time-start_time)

# Put the corrected expression matrix on a copy and derive a PCA embedding from it
adata_int = adata.copy()
adata_int.X = output_results

sc.tl.pca(adata_int, n_comps=50)
sc.pp.neighbors(adata_int, use_rep="X_pca")
sc.tl.umap(adata_int, min_dist=0.5)
sc.pl.umap(adata_int, color=["tech", "celltype"], frameon=False, ncols=1)
adata.obsm['imap'] = adata_int.obsm['X_pca']

# Score the iMAP embedding (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="imap", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="imap", scale=True))

In [None]:
#scDML

start_time = time.time()

# Number of annotated cell types; passed to scDML as the expected cluster count.
# Column names match the rest of the notebook ('celltype' / 'tech').
ncluster = len(adata.obs['celltype'].unique())

# Only `import scDML` is available in this notebook, so qualify the class name.
scdml = scDML.scDMLModel()
adata_int = adata.copy()
adata_int = scdml.preprocess(adata_int, cluster_method="louvain", resolution=3.0, batch_key='tech')
scdml.integrate(adata_int, batch_key='tech', ncluster_list=[ncluster],
                expect_num_cluster=ncluster, merge_rule="rule2", out_dim=50)
end_time = time.time()
print('time taken to run :', end_time - start_time)

adata.obsm['scDML'] = adata_int.obsm['X_emb']
sc.pp.neighbors(adata, use_rep='scDML')
sc.tl.umap(adata)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

# Score the scDML embedding (previously the scCRAFT embedding was scored here).
print(scib.me.silhouette(adata, label_key="celltype", embed="scDML", scale=True))
print(scib.me.silhouette_batch(adata, batch_key="tech", label_key="celltype", embed="scDML", scale=True))

In [None]:
# Seurat R pipeline
```R
library(Seurat)
library(anndata)
library(reticulate)
library(SeuratWrappers)
library(SeuratDisk)

# Convert the AnnData file to h5Seurat, load it, and cache the object as RDS.
Convert(
  "/path/Lung_atlas_raw.h5ad",
  "h5seurat",
  assay = "RNA",
  overwrite = TRUE,
  verbose = TRUE
)
seurat_obj <- LoadH5Seurat(
  "/path/Lung_atlas_raw.h5seurat",
  assay = "RNA",
  meta.data = TRUE
)
saveRDS(seurat_obj, file = "/path/Lung_atlas_raw.rds")

# Reload from the cache and remember the cell order before the assay is split,
# so the exported embedding can be written back in that order.
seurat_obj <- readRDS("/path/Lung_atlas_raw.rds")
original_cell_order <- colnames(seurat_obj@assays$RNA@counts)

# Split the RNA assay by batch, normalize with SCTransform, and run PCA.
seurat_obj[["RNA"]] <- split(seurat_obj[["RNA"]], f = seurat_obj$batch)
seurat_obj <- SCTransform(seurat_obj)
seurat_obj <- RunPCA(seurat_obj, npcs = 50, verbose = FALSE)

# Integrate the per-batch layers with reciprocal PCA.
seurat_obj <- IntegrateLayers(
  object = seurat_obj,
  method = RPCAIntegration,
  new.reduction = "integrated.rpca",
  normalization.method = "SCT",
  verbose = FALSE
)

# Export the first 50 integrated components, rows in the original cell order.
integrated_rpca_embeddings <- Embeddings(
  object = seurat_obj,
  reduction = "integrated.rpca"
)
pca_embeddings <- integrated_rpca_embeddings[, 1:50]
pca_embeddings_ordered <- pca_embeddings[
  match(original_cell_order, rownames(pca_embeddings)),
]
write.csv(
  pca_embeddings_ordered,
  file = "/path/Lung_atlas_seurat.csv",
  row.names = TRUE
)
```

In [None]:
# Import the embedding exported by the Seurat R pipeline.
# NOTE(review): the path is a placeholder and names the Lung atlas export, while
# adata in this notebook is the pancreas dataset - confirm the intended file.
pca_embeddings = pd.read_csv('/path/Lung_atlas_seurat.csv', index_col=0)
pca_embeddings.index = pca_embeddings.index.astype(str)

# Fail early with a clear message instead of attaching rows to the wrong cells.
if pca_embeddings.shape[0] != adata.n_obs:
    raise ValueError(
        f"Seurat embedding has {pca_embeddings.shape[0]} rows but adata has {adata.n_obs} cells"
    )

# Align by cell name when the CSV carries the same names as adata; otherwise
# keep the positional order written by the R script.
if set(pca_embeddings.index) == set(adata.obs_names):
    pca_embeddings = pca_embeddings.loc[adata.obs_names]

adata.obsm['X_seurat'] = pca_embeddings.values
sc.pp.neighbors(adata, use_rep="X_seurat")
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color=["tech", "celltype"], frameon=False, ncols=1)

Open questions
- How consistent are the results when training with the critic versus the discriminator?
- Which cell types collapse first?